diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f1f9a781a07..712f35400e4a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -924,7 +924,14 @@ if(VLLM_GPU_LANG STREQUAL "HIP") set(VLLM_ROCM_EXT_SRC "csrc/rocm/torch_bindings.cpp" "csrc/rocm/skinny_gemms.cu" - "csrc/rocm/attention.cu") + "csrc/rocm/attention.cu" + "csrc/rocm/ck_extensions/ck_tile_gemm_bf16/gemm_bf16.cu") + + set(CK_TILE_INCLUDE_DIR + "csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include" + "csrc/rocm/ck_extensions/include/" + "opt/rocm/include" + ) define_gpu_extension_target( _rocm_C @@ -933,6 +940,7 @@ if(VLLM_GPU_LANG STREQUAL "HIP") SOURCES ${VLLM_ROCM_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CK_TILE_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) endif() diff --git a/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/gemm_bf16.cu b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/gemm_bf16.cu new file mode 100755 index 000000000000..999f98baba58 --- /dev/null +++ b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/gemm_bf16.cu @@ -0,0 +1,412 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +// #include +// #include +// #include +#include +#include +#include +#include +#include +#include +#include "core/registration.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "gemm_utils.hpp" +#include "gemm_bool_switch.hpp" +#include "gemm_setting.hpp" +#include "tile_gemm_shape_bias.hpp" +#include "gemm_bias_kernel.hpp" + +//#include "run_gemm_example.inc" + +template +float gemm_calc(const ck_tile::GemmHostArgs_bias& args, const ck_tile::stream_config& s); + +template +float gemm_calc_dipatch(const ck_tile::GemmHostArgs_bias& args, const ck_tile::stream_config& s); + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +torch::Tensor ck_tile_gemm_bf16( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& bias, + torch::Tensor& Y + ) +{ + + TORCH_CHECK(XQ.dtype() == WQ.dtype(), "Weights and activations should have the same dtype!"); + + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + using ALayout = Row; + using BLayout = Col; + using CLayout = Row; + using ELayout = Row; + + // printf("1\n"); + ck_tile::index_t M = XQ.size(0); + ck_tile::index_t N = WQ.size(0); + ck_tile::index_t K = XQ.size(1); + ck_tile::index_t kbatch = 1; + int n_warmup = 0; + int n_repeat = 1; + bool persistent = 0; + ck_tile::index_t stride_A = 0; + ck_tile::index_t stride_B = 0; + ck_tile::index_t stride_C = 0; + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(ALayout{})); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(BLayout{})); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using CDataType = ck_tile::bf16_t; + + using AccDataType = typename GemmTypeConfig::AccDataType; + + + ck_tile::GemmHostArgs_bias args = {XQ.data_ptr(), + WQ.data_ptr(), + Y.data_ptr(), + bias.data_ptr(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + stride_C}; + + + + + float ave_time; + const at::cuda::OptionalCUDAGuard device_guard(device_of(XQ)); + at::cuda::getCurrentCUDAStream().stream(); + const cudaStream_t hip_stream = 
at::cuda::getCurrentCUDAStream(); + + auto config = ck_tile::stream_config{hip_stream, false, 0, n_warmup, n_repeat, false, false, 1}; + + if(persistent) + { + ave_time = gemm_calc( + args, config); + } + else + { + ave_time = gemm_calc( + args, config); + } + return Y; +} + + +template +float gemm_calc(const ck_tile::GemmHostArgs_bias& args, const ck_tile::stream_config& s) +{ + float ave_time=0.0; + KM_KN_KK_SWITCH( + args.M, + kM, + args.N, + kN, + args.K, + MaxK, + [&]{ + ave_time = gemm_calc_dipatch(args, s); + }); + return ave_time; +} + +template +float gemm_calc_dipatch(const ck_tile::GemmHostArgs_bias& args, const ck_tile::stream_config& s) +{ + using GemmTileShapeConfig_ = GemmTileShapeConfig; + using GemmConfig = GemmConfigPreshuffle_2; + //using GemmConfig = GemmConfigComputeV3; + + // GemmConfigPreshufle_2 GemmConfig; + const bool pad_M = !(args.M % GemmTileShapeConfig_::M_Tile == 0); + const bool pad_N = !(args.N % GemmTileShapeConfig_::N_Tile == 0); + const bool pad_K = !(args.K % GemmTileShapeConfig_::K_Tile == 0); + float ave_time{0}; + + BOOL_SWITCH_3( + pad_M, + padM_, + pad_N, + padN_, + pad_K, + padK_, + [&]{ + using GemmShape = ck_tile::TileGemmShape_bias< + ck_tile::sequence, + ck_tile::sequence, + ck_tile::sequence, + GemmConfig::PermuteA, + GemmConfig::PermuteB>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + float ave_time{0}; + + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + UniversalGemmProblem::kBlockSize, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmTileShapeConfig_::M_Warp, + GemmTileShapeConfig_::N_Warp, + GemmTileShapeConfig_::M_Warp_Tile, + GemmTileShapeConfig_::N_Warp_Tile, + GemmTileShapeConfig_::K_Warp_Tile, + UniversalGemmProblem::TransposeC, + memory_operation, + GemmConfig::NumWaveGroups>>; + using Kernel = ck_tile::GemmKernel_bias; + auto kargs = Kernel::MakeKernelArgs(args); + + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + // rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_preprocess( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + + }); + return ave_time; +} + diff --git a/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/ck_tile_gemm_bf16.h b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/ck_tile_gemm_bf16.h new file mode 100755 index 000000000000..4dcdea46b44f --- /dev/null +++ b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/ck_tile_gemm_bf16.h @@ -0,0 +1,10 @@ +#pragma once +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +#include +#include +torch::Tensor ck_tile_gemm_bf16( + torch::Tensor &XQ, + torch::Tensor &WQ, + torch::Tensor &bias, + torch::Tensor &Y); diff --git a/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_bias_kernel.hpp b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_bias_kernel.hpp new file mode 100755 index 000000000000..37e79f609484 --- /dev/null +++ b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_bias_kernel.hpp @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/stream_utils.hpp" +#include "ck_tile/core/utility/env.hpp" +#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" +#include "ck_tile/core/utility/type_traits.hpp" +#include "universal_gemm_bias_kernel.hpp" +namespace ck_tile { + +/// @brief The GEMM kernel host arguments. 
+/// +/// @par Overview +/// This structure is passed to @ref GemmKernel "GemmKernel" when creating kernel arguments +/// object. It contain all necessary information required to build proper kernel argument +/// and launch kernel on GPU. +/// This structure defines the GEMM problem configuration by stating all required information +/// like M,N,K sizes and respective strides. +struct GemmHostArgs_bias +{ + CK_TILE_HOST GemmHostArgs_bias() = default; + CK_TILE_HOST GemmHostArgs_bias(const void* a_ptr_, + const void* b_ptr_, + void* e_ptr_, + const void* bias_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + index_t stride_A_, + index_t stride_B_, + index_t stride_E_) + : a_ptr(a_ptr_), + b_ptr(b_ptr_), + e_ptr(e_ptr_), + bias_ptr(bias_ptr_), + M(M_), + N(N_), + K(K_), + stride_A(stride_A_), + stride_B(stride_B_), + stride_E(stride_E_), + k_batch(k_batch_) + { + } + + const void* a_ptr; + const void* b_ptr; + union + { + void* e_ptr; + void* c_ptr; + }; + const void* bias_ptr; + index_t M; + index_t N; + index_t K; + index_t stride_A; + index_t stride_B; + + union + { + index_t stride_E; + index_t stride_C; + }; + + index_t k_batch; +}; + +template +struct GemmKernel_bias +{ + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. + using UniversalGemmKernel_bias = + UniversalGemmKernel_bias; + + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + + /// @brief Specify the layout configurations for A, B, E and D + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + + /// @brief Specify the data type configurations for A, B, E and D + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using EDataType = remove_cvref_t; + + /// @brief ALayout and ADataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "ALayout and ADataType must be scalars. Multiple parameters are not currently supported."); + + /// @brief BLayout and BDataType are expected to be scalars, not a tuple. + static_assert( + !is_detected::value && !is_detected::value, + "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); + + /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "C/ELayout and C/EDataType must be scalars."); + + static constexpr index_t NumATensor = 1; + static constexpr index_t NumBTensor = 1; + + CK_TILE_HOST static auto GetName() -> const std::string + { + return UniversalGemmKernel_bias::GetName(); + } + + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3 + { + return UniversalGemmKernel_bias::GridSize(M, N, KBatch); + } + + CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 + { + return UniversalGemmKernel_bias::MaxOccupancyGridSize(s); + } + + CK_TILE_HOST static constexpr auto BlockSize() -> dim3 + { + return UniversalGemmKernel_bias::BlockSize(); + } + + CK_TILE_HOST static constexpr auto MakeKernelArgs(const GemmHostArgs_bias& hostArgs) -> + typename UniversalGemmKernel_bias::KernelArgs + { + /// @brief Universal GEMM requires array objects and corresponding stride information for + /// matrices A, B. 
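+        /// A and B are wrapped as single-element arrays (NumATensor == NumBTensor == 1) and the
+        /// Ds tensor lists stay empty; the bias pointer is forwarded separately so the pipeline
+        /// can add it to the C accumulator.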
+ return UniversalGemmKernel_bias::MakeKernelArgs( + UniversalGemmHostArgs_bias( + {hostArgs.a_ptr}, + {hostArgs.b_ptr}, + {/*hostArgs.ds_ptr*/}, + hostArgs.e_ptr, + hostArgs.bias_ptr, + hostArgs.k_batch, + hostArgs.M, + hostArgs.N, + hostArgs.K, + {hostArgs.stride_A}, + {hostArgs.stride_B}, + {/*hostArgs.stride_Ds*/}, + hostArgs.stride_E)); + } + + CK_TILE_HOST static auto + IsSupportedArgument(const typename UniversalGemmKernel_bias::KernelArgs& kargs) -> bool + { + return UniversalGemmKernel_bias::IsSupportedArgument(kargs); + } + + CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel_bias::KernelArgs kargs) const -> void + { + UniversalGemmKernel_bias{}.template operator()(kargs); + } +}; +} // namespace ck_tile diff --git a/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_bias_pipeline_ag_bg_cr_comp_v3.hpp b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_bias_pipeline_ag_bg_cr_comp_v3.hpp new file mode 100755 index 000000000000..d593d74ba5a9 --- /dev/null +++ b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_bias_pipeline_ag_bg_cr_comp_v3.hpp @@ -0,0 +1,702 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "gemm_bias_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BaseGemmPipelineAgBgCrCompV3_bias +{ + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + { + if(BlockHasHotloop(num_loop)) + { + return TailNumber::Full; + } + else + { + if(num_loop == 1) + { + return TailNumber::Odd; + } + else + { + return TailNumber::Even; + } + } + } + + template + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number) + { + // Handle all the valid cases. + if(has_hot_loop) + { + if(tail_number == TailNumber::Full) + { + return run_func(bool_constant{}, + integral_constant{}); + } + } + else + { + if(tail_number == TailNumber::Odd) + { + return run_func(bool_constant{}, + integral_constant{}); + } + else if(tail_number == TailNumber::Even) + { + return run_func(bool_constant{}, + integral_constant{}); + } + } +#if defined(__HIP_DEVICE_COMPILE__) + // This path should be unreachable in device code if tail_number is valid. + __builtin_unreachable(); +#else + // If execution reaches here, it's an invalid combination of arguments. 
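+        // Host-only path (the device build marks this branch unreachable above): throw with a
+        // descriptive message rather than silently launching an invalid configuration.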
+ if(has_hot_loop) + { + throw std::logic_error("Invalid TailNumber: If has_hot_loop is true, tail_number must " + "be TailNumber::Full."); + } + else + { + throw std::logic_error("Invalid TailNumber: If has_hot_loop is false, tail_number must " + "be TailNumber::Odd or TailNumber::Even."); + } +#endif + } +}; + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 +template +struct GemmPipelineAgBgCrCompV3_bias : public BaseGemmPipelineAgBgCrCompV3_bias +{ + using Base = BaseGemmPipelineAgBgCrCompV3_bias; + using PipelineImplBase = GemmPipelineAgBgCrImplBase; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockGemm = remove_cvref_t())>; + using I0 = number<0>; + using I1 = number<1>; + using I2 = number<2>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + + static constexpr index_t MPerBlock = BlockGemmShape::kM; + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA(); } + static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB(); } + static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC(); } + + static constexpr index_t APackedSize = + ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } + static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + static constexpr index_t Preshuffle = Problem::Preshuffle; + + static constexpr bool HasHotLoop = + Problem::HasHotLoop; // Base::BlockHasHotloop(Problem::num_loop); + static constexpr auto TailNum = + Problem::TailNum; // Base::GetBlockLoopTailNum(Problem::num_loop); + static constexpr auto Scheduler = Problem::Scheduler; + + static constexpr auto is_a_load_tr_v = bool_constant{}; + static constexpr auto is_b_load_tr_v = bool_constant{}; + + using Base::PrefetchStages; + using Base::UsePersistentKernel; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + return concat('_', "pipeline_AgBgCrCompV3", + concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize, + concat('x', WaveNumM, WaveNumN), + concat('x', kPadM, kPadN, kPadK)); + // clang-format on + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + CK_TILE_HOST static std::string Print() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = 
BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + // Below should be equal to AK1|BK1 + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + auto str = std::stringstream{}; + + str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << "\n" + << "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n" + << "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num + << "\n" + << "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num + << "\n" + << "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n" + << "C MFMA inst: " << C_MFMA_Inst_Num << "\n" + << "KPack: " << BlockGemm::Traits::KPack << "\n" + << "PrefetchStages: " << PrefetchStages << "\n"; + return str.str(); + } + + template + struct PipelineImpl : public PipelineImplBase + { + }; + + template <> + struct PipelineImpl : public PipelineImplBase + { + using Base = PipelineImplBase; + + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() + { + constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; + constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN; + constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK; + + constexpr index_t WaveSize = 64; + constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{}); + constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{}); + + // Below should be equal to AK1|BK1 + constexpr index_t A_LDS_Read_Width = GetSmemPackA(); + constexpr index_t B_LDS_Read_Width = GetSmemPackB(); + + constexpr index_t A_LDS_Write_Width = GetSmemPackA(); + constexpr index_t B_LDS_Write_Width = GetSmemPackB(); + + constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA()); + constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB()); + + constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width); + constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width); + + constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width); + constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width); + + constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / + (MPerXDL * NPerXDL * KPerXDL); + + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 
16bytes + constexpr auto num_ds_read_inst_a = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num + : A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num + : B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / + // sizeof(BDataType) + // ? sizeof(ComputeDataType) / + // sizeof(ADataType) : sizeof(ComputeDataType) + // / sizeof(BDataType); + constexpr auto num_mfma_stage1 = + num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * 
ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "A/B Dram block window should have the same data type as appropriate " + "([A|B]DataType) defined in Problem definition!"); + + constexpr bool is_a_col_major = + std::is_same_v; + constexpr bool is_b_row_major = std::is_same_v; + + static_assert(is_a_col_major + ? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "A block window has incorrect lengths for defined ALayout!"); + static_assert(is_b_row_major + ? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]) + : (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] && + KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), + "B block window has incorrect lengths for defined BLayout!"); + + // ------------------------------------------------------------------------------------ + // Definitions of all needed tiles + + // A/B tiles in LDS + auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); + + // Tile distribution for load from lds + constexpr auto a_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = + make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + + // A DRAM tile window for load + // A LDS tile window for store + // A LDS tile for block GEMM + auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] = + Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr); + + // B DRAM tile window for load + // B LDS tile window for store + // B LDS tile for block GEMM + auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] = + Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr); + + // Block GEMM + auto block_gemm = BlockGemm(); + + + + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + bias_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBiasDramTileDistribution()); + (void)bias_dram_window; + auto c_block_tile = block_gemm.MakeCBlockTile(); + + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + + using ABlockTile = + decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = + decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + ABlockTile a_block_tile; + BBlockTile b_block_tile; + + using ADramTileWindowStep = typename 
ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + + constexpr ADramTileWindowStep a_dram_tile_window_step = + is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = + is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock); + + // ----------------------------------------------------------------------------------------- + // Gemm pipeline start + + // prefetch + // global read 0 + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + auto bias = load_tile(bias_dram_window); + + sweep_tile(c_block_tile, [&](auto idx) { + // compute x = bias + x + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + c_block_tile(idx) += type_convert(bias[j_idx]) ; + + }); + + // LDS write 0 + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + block_sync_lds(); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasHotLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + if constexpr(is_a_col_major && !is_a_load_tr_v()) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major && !is_b_load_tr_v()) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + + block_sync_lds(); + + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 1)); + } + // tail + if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) + { + // Leak last MFMA 
block to epilogue region, cover the potential lds-shuffle + // latency + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + else + { + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, a_block_tile); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + } + else + { + Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + } + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + } + else + { + Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + } + block_sync_lds(); + block_gemm.LocalPrefetch( + a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + } + // __builtin_amdgcn_sched_barrier(0); + // constexpr auto s0acc_spans = decltype(c_block_tile)::get_distributed_spans(); + // sweep_tile_span(s0acc_spans[number<0>{}], [&](auto idx0) { + // sweep_tile_span(s0acc_spans[number<1>{}], [&](auto idx1) { + // constexpr auto i_j_idx = make_tuple(idx0, idx1); + // const auto val = c_block_tile[i_j_idx]; + + // printf("ck_:::: y = %.6f\n", static_cast(val)); + // // s0_acc(i_j_idx) = (val > 0) ? val : val * lrelu_alpha; + // }); + // }); + + return c_block_tile; + + } + }; + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + a_element_func, + b_dram_block_window_tmp, + b_element_func, + num_loop, + p_smem); + } + + /** + * @brief This function runs the pipeline by wrapping it with the tail handler. + * + * @note This is used by the persistent gemm kernel variants that don't determine + * hot loop and tail number on the host side, e.g. grouped gemm kernel. + */ + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* p_smem) const + { + const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { + constexpr bool hot_loop = hot_loop_.value; + constexpr auto tail_num = tail_num_.value; + constexpr auto PassThrough = [](const auto& x) { return x; }; + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + PassThrough, + b_dram_block_window_tmp, + PassThrough, + num_loop, + p_smem); + }; + return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); + } + + /** + * @brief This function runs the pipeline using compile-time known hot loop and tail number. + * @param num_loop The number of loop iterations. This is determined at runtime due to e.g. + * SplitK. + * @note This is used by the kernel variants that are able to determine + * hot loop and tail number on the host side, e.g. non-persistent gemm kernel. 
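+     * @param bias_dram_block_window_tmp Block window over the bias vector; the bias tile is
+     *        loaded once and added to the zero-initialized C accumulator, broadcast along the
+     *        M dimension.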
+ */ + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return PipelineImpl{}.template operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_dram_block_window_tmp, + [](const BDataType& b) { return b; }, + bias_dram_block_window_tmp, + num_loop, + p_smem); + } +}; + +} // namespace ck_tile diff --git a/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_bias_universal_pipeline_ag_bg_cr_policy.hpp b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_bias_universal_pipeline_ag_bg_cr_policy.hpp new file mode 100755 index 000000000000..cc81a915f854 --- /dev/null +++ b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_bias_universal_pipeline_ag_bg_cr_policy.hpp @@ -0,0 +1,702 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" +namespace ck_tile { + +template +struct UniversalGemmBasePolicy_bias +{ +#if defined(__gfx950__) + template + static constexpr bool is_a_load_tr = + std::is_same_v, tensor_layout::gemm::ColumnMajor>; + template + static constexpr bool is_b_load_tr = + std::is_same_v, tensor_layout::gemm::RowMajor>; +#else + template + static constexpr bool is_a_load_tr = false; + template + static constexpr bool is_b_load_tr = false; +#endif + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + + static constexpr auto ATileAccessPattern = tile_distribution_pattern::thread_raked; + static constexpr auto BTileAccessPattern = tile_distribution_pattern::thread_raked; + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + + using ADataType = remove_cvref_t; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + if constexpr(is_a_load_tr) + { + // TODO: better lds descriptor for performance + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( // + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + return a_lds_block_desc_0; + } + else + { + constexpr index_t KPack = GetSmemPackA(); + + constexpr auto DataTypeSize = sizeof(ADataType); + constexpr auto MLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 
1 : (32 * 4 / KPerBlock / DataTypeSize); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; + } + } + + /** + * @brief Create LDS block descriptor for B tensor. + * + * @tparam Problem Gemm pipeline problem. + * @return B tensor LDS block descriptor. + */ + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + using BDataType = remove_cvref_t; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + +#if 1 + if constexpr(is_b_load_tr) + { + // TODO: better lds descriptor for performance + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( // + make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + return b_lds_block_desc_0; + } + else + // else if constexpr(std::is_same_v) + { + constexpr index_t KPack = GetSmemPackB(); + constexpr auto BK0 = number{}; + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 
1 : (32 * 4 / KPerBlock / DataTypeSize); + + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple( + BK0 * number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + BK0 * number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(number{}, BK0)), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple(BK0, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; + } +#else + else // B is Row Major + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t VecLoadSize = GetVectorSizeB(); + using TileEncodingPattern = TileDistributionEncodingPattern2D; + + constexpr auto BK0 = number{}; + constexpr auto BK1 = number{}; + // constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N0 = TileEncodingPattern::X0; + constexpr auto N1 = NPerBlock / N0; + + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + constexpr auto NPerXdl = number{}; + + // constexpr auto KThreadWrite = + // BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto KThreadWrite = TileEncodingPattern::Y2; + constexpr auto K0PerThreadWrite = BK0 / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0 / KThreadRead; + + constexpr auto kfold = + (BK1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (BK1 * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1 * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1 * NPerXdl * sizeof(BDataType))) > N0 + ? 
N0 + : 128 / (BK1 * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + BK1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_xor_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(BK1)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(number{}), + make_pass_through_transform(number{}), + make_unmerge_transform(make_tuple(number{}, number{})), + make_unmerge_transform(make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(BK1)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<1>{}, + sequence<2>{}, + sequence<0, 3>{}, + sequence<4, 5>{}, + sequence<6>{}, + sequence<7>{})); + + // constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + // b_lds_block_desc_unmerged, + // make_tuple(make_merge_transform_v3_division_mod( + // make_tuple(number{}, + // number{}, + // number{}, + // number{})), + // make_merge_transform_v3_division_mod( + // make_tuple(number{}, number{}, number{})), + // make_pass_through_transform(BK1)), + // make_tuple(sequence<0, 1, 4, 2>{}, sequence<5, 6, 3>{}, sequence<7>{}), + // make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + constexpr auto b_lds_block_desc_kn = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, + number{}, + number{}, + number{}, + BK1)), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // return b_lds_block_desc_bk0_n_bk1; + return b_lds_block_desc_kn; + + // constexpr auto b_lds_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor( + // make_tuple(BK0, number{}, number{}), + // make_tuple(number{}, number{}, number<1>{}), + // number{}, + // number<1>{}); + + // constexpr auto b_lds_block_desc = transform_tensor_descriptor( + // b_lds_block_desc_bk0_n_bk1, + // make_tuple(make_pass_through_transform(number{}), + // make_merge_transform_v3_division_mod(make_tuple(BK0, + // number{}))), + // make_tuple(sequence<1>{}, sequence<0, 2>{}), + // make_tuple(sequence<0>{}, sequence<1>{})); + + // return b_lds_block_desc; + } +#endif + } + + /** + * @brief Get the maximum global memory vector load size. + * + * @tparam Problem The UniversalGemmPipelineProblem object. + * @tparam DataType The tensor data type we're considering. + * @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B). + * @tparam XPerTile The contiguous Tile dimension size. + * @return Maximum DRAM vector load size. 
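+     * @note Picks the widest load (32-byte only for PackedSize == 2 types, otherwise up to
+     *       16-byte) whose element count evenly divides both XPerTile and the per-thread
+     *       element count, falling back to PackedSize.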
+ */ + template + CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize() + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize; + constexpr index_t PackedSize = + ck_tile::numeric_traits>::PackedSize; + + // Assume DataType is even! + if constexpr(XPerTile % (PackedSize * 32 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 32 / sizeof(DataType)) == 0 && + PackedSize == 2) + { + return (PackedSize * 32 / sizeof(DataType)); + } + else if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0) + { + return (PackedSize * 16 / sizeof(DataType)); + } + else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0) + { + return (PackedSize * 8 / sizeof(DataType)); + } + else if constexpr(sizeof(DataType) >= PackedSize * 4 && + XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0) + { + return (PackedSize * 4 / sizeof(DataType)); + } + else if constexpr(sizeof(DataType) >= PackedSize * 2 && + XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 && + elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0) + { + return (PackedSize * 2 / sizeof(DataType)); + } + else + { + return PackedSize; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() + { + using ALayout = remove_cvref_t; + using ADataType = remove_cvref_t; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + if constexpr(std::is_same_v) + { + return GetGlobalVectorLoadSize(); + } + else + { + return GetGlobalVectorLoadSize(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() + { + using BLayout = remove_cvref_t; + using BDataType = remove_cvref_t; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + if constexpr(std::is_same_v) + { + return GetGlobalVectorLoadSize(); + } + else + { + return GetGlobalVectorLoadSize(); + } + } + + /** + * @brief Get the vector store size for C tensor. + * + * @tparam Problem - Gemm pipeline problem class. + * + * @note The vector store size for output C tensor would depend on multiple factors + * like its data layout and warp gemm C transposition. In general it would + * be the number of consecutive elements in contiguous C dimension hold by + * single thread. + * + * @return The vector store size for C tensor. + */ + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC() + { + using BlockGemm = remove_cvref_t())>; + using WG = typename BlockGemm::WarpGemm; + + constexpr bool TransposeC = Problem::TransposeC; + using CLayout = typename Problem::CLayout; + using CWarpDstr = typename WG::CWarpDstr; + + // N is contiguous dimension + if constexpr(std::is_same_v) + { + if constexpr(TransposeC) + { + // In this case each thread has multiple consecutive elements in + // N dimension, however consecutive threads' elements have stride. 
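+                // The store width is therefore the innermost per-lane Y-length of the warp C
+                // distribution (checked against kCM1PerLane below).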
+ constexpr index_t NDimY = CWarpDstr::NDimY; + constexpr auto c_warp_y_lengths = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths(); + static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane == + c_warp_y_lengths.get(number{})); + return c_warp_y_lengths.get(number{}); + } + else + { + // In this case each thread has just a single item in Ndim + return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN; + } + } + // M is contiguous dimension + else if constexpr(std::is_same_v) + { + if constexpr(TransposeC) + { + // In this case each thread has just a single item in Mdim + return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN; + } + else + { + // In this case each thread has multiple consecutive elements in + // M dimension, however consecutive threads' elements have stride. + constexpr index_t NDimY = CWarpDstr::NDimY; + constexpr auto c_warp_y_lengths = + CWarpDstr{}.get_ys_to_d_descriptor().get_lengths(); + static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane == + c_warp_y_lengths.get(number{})); + return c_warp_y_lengths.get(number{}); + } + } + else + { + static_assert(false, "Unsupported CLayout!"); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() + { + return Problem::TransposeC; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using ALayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + // Tile: MPerBlock X KPerBlock + if constexpr(std::is_same_v) + { + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::Make2DStaticTileDistribution(); + } + // Tile: KPerBlock X MPerBlock + else + { + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::Make2DStaticTileDistribution(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + { + using BLayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? 
Problem::VectorSizeB : GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + // Tile: KPerBlock X NPerBlock + if constexpr(std::is_same_v) + { + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::Make2DStaticTileDistribution(); + } + // Tile: NPerBlock X KPerBlock + else + { + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::Make2DStaticTileDistribution(); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution() + { + using ALayout = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeA(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution() + { + using BLayout = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t VecLoadSize = GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + using TileEncodingPattern = TileDistributionEncodingPattern2D; + return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); + } + + template + __host__ __device__ static constexpr auto MakeBiasDramTileDistribution() + { + + using BlockGemm = remove_cvref_t())>; + + using WG = typename BlockGemm::WarpGemm; + constexpr index_t MWarp = Problem::BlockGemmShape::w_kM; + constexpr index_t NWarp = Problem::BlockGemmShape::w_kN; + // constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + // constexpr index_t kBlockSize = Problem::kBlockSize; + // Gemm0 should be transposed + + + constexpr index_t M1 = WG::WarpGemmAttribute::Impl::kCMLane; + constexpr index_t M0 = MWarp; + // constexpr index_t M2 = kMPerBlock/(WG::WarpGemmAttribute::Impl::kM*M0); + constexpr index_t N4 = WG::WarpGemmAttribute::Impl::kCM0PerLane; + constexpr index_t N3 = WG::WarpGemmAttribute::Impl::kCNLane; + // constexpr index_t N2 = WG::WarpGemmAttribute::Impl::kCM1PerLane; + constexpr index_t N1 = NWarp; + constexpr index_t N0 = kNPerBlock / (N1 * WG::WarpGemmAttribute::Impl::kN); + // //if(get_thread_global_1d_id() ==64) + // //printf("%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,\n",MWarp,NWarp,kNPerBlock,M1,M0,N4,N3,N2,N1,N0); + // return make_static_tile_distribution( + // tile_distribution_encoding, + // tuple>, + // tuple, sequence<0, 1>>, + // tuple, sequence<2, 3>>, + // sequence<1, 1, 1>, + // sequence<2, 0, 4>>{}); + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, + tuple>, + tuple, sequence<0, 1>>, + tuple, sequence<1, 2>>, + sequence<1, 1>, + sequence<0, 3>>{}); + } + + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() + { + using BlockGemm = remove_cvref_t())>; + constexpr index_t KPack = BlockGemm::Traits::KPack; + return KPack; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() + { + using BlockGemm = remove_cvref_t())>; + constexpr index_t KPack = BlockGemm::Traits::KPack; + 
return KPack; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() + { + constexpr auto a_lds_desc = MakeALdsBlockDescriptor(); + constexpr index_t smem_size_a = integer_least_multiple( + sizeof(typename Problem::ADataType) * a_lds_desc.get_element_space_size(), 16); + return smem_size_a; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() + { + constexpr auto b_lds_desc = MakeBLdsBlockDescriptor(); + constexpr index_t smem_size_b = integer_least_multiple( + sizeof(typename Problem::BDataType) * b_lds_desc.get_element_space_size(), 16); + return smem_size_b; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + constexpr index_t smem_size_a = GetSmemSizeA(); + constexpr index_t smem_size_b = GetSmemSizeB(); + + return smem_size_a + smem_size_b; + } +}; + +// UniversalGemm Policy +struct UniversalGemmPipelineAgBgCrPolicy_bias + : public UniversalGemmBasePolicy_bias +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + + constexpr index_t vector_size = + DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType); + constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size(); + constexpr auto wg_attr_num_access = + !(is_a_load_tr || is_b_load_tr) ? WGAttrNumAccessEnum::Single + : vector_size == thread_elements ? WGAttrNumAccessEnum::Single + : vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double + : vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad + : WGAttrNumAccessEnum::Invalid; + + using WarpGemm = WarpGemmMfmaDispatcher; + using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; + return BlockUniversalGemmAsBsCr{}; + } +}; + +} // namespace ck_tile diff --git a/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_bool_switch.hpp b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_bool_switch.hpp new file mode 100755 index 000000000000..156dc84b4569 --- /dev/null +++ b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_bool_switch.hpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#define BOOL_SWITCH(COND1, CONST_NAME1, ...) \ + [&] { \ + if(COND1) \ + { \ + constexpr bool CONST_NAME1 = true; \ + __VA_ARGS__(); \ + } \ + else \ + { \ + constexpr bool CONST_NAME1 = false; \ + __VA_ARGS__(); \ + } \ + }() + +#define BOOL_SWITCH_2(COND1, CONST_NAME1, COND2, CONST_NAME2, ...) \ + [&] { \ + if(COND1) \ + { \ + constexpr bool CONST_NAME1 = true; \ + BOOL_SWITCH(COND2, CONST_NAME2, ##__VA_ARGS__); \ + } \ + else \ + { \ + constexpr bool CONST_NAME1 = false; \ + BOOL_SWITCH(COND2, CONST_NAME2, ##__VA_ARGS__); \ + } \ + }() + +#define BOOL_SWITCH_3(COND1, CONST_NAME1, COND2, CONST_NAME2, COND3, CONST_NAME3, ...) \ + [&] { \ + if(COND1) \ + { \ + constexpr bool CONST_NAME1 = true; \ + BOOL_SWITCH_2(COND2, CONST_NAME2, COND3, CONST_NAME3, ##__VA_ARGS__); \ + } \ + else \ + { \ + constexpr bool CONST_NAME1 = false; \ + BOOL_SWITCH_2(COND2, CONST_NAME2, COND3, CONST_NAME3, ##__VA_ARGS__); \ + } \ + }() + +#define BOOL_SWITCH_4( \ + COND1, CONST_NAME1, COND2, CONST_NAME2, COND3, CONST_NAME3, COND4, CONST_NAME4, ...) 
\ + [&] { \ + if(COND1) \ + { \ + constexpr bool CONST_NAME1 = true; \ + BOOL_SWITCH_3( \ + COND2, CONST_NAME2, COND3, CONST_NAME3, COND4, CONST_NAME4, ##__VA_ARGS__); \ + } \ + else \ + { \ + constexpr bool CONST_NAME1 = false; \ + BOOL_SWITCH_3( \ + COND2, CONST_NAME2, COND3, CONST_NAME3, COND4, CONST_NAME4, ##__VA_ARGS__); \ + } \ + }() + +#define KK_SWITCH(KK_SZ, CONST_NAME, ...) \ + [&] { \ + constexpr ck_tile::index_t CONST_NAME = 512; \ + __VA_ARGS__(); \ + }() + +#define KN_KK_SWITCH(KN_SZ, CONST_NAME1, KK_SZ, CONST_NAME2, ...) \ + [&] { \ + if(KN_SZ <= 128) \ + { \ + constexpr ck_tile::index_t CONST_NAME1 = 128; \ + KK_SWITCH(KK_SZ, CONST_NAME2, ##__VA_ARGS__); \ + } \ + else if (KN_SZ <= 2880) \ + { \ + constexpr ck_tile::index_t CONST_NAME1 = 2880; \ + KK_SWITCH(KK_SZ, CONST_NAME2, ##__VA_ARGS__); \ + } \ + else \ + { \ + constexpr ck_tile::index_t CONST_NAME1 = 5120; \ + KK_SWITCH(KK_SZ, CONST_NAME2, ##__VA_ARGS__); \ + } \ + }() + +#define KM_KN_KK_SWITCH( \ + KM_SZ, CONST_NAME1, KN_SZ, CONST_NAME2, KK_SZ, CONST_NAME3, ...) \ + [&] { \ + if(KM_SZ <= 128) \ + { \ + constexpr ck_tile::index_t CONST_NAME1 = 128; \ + KN_KK_SWITCH(KN_SZ, CONST_NAME2, KK_SZ, CONST_NAME3, ##__VA_ARGS__); \ + } \ + else if(KM_SZ <= 512) \ + { \ + constexpr ck_tile::index_t CONST_NAME1 = 512; \ + KN_KK_SWITCH(KN_SZ, CONST_NAME2, KK_SZ, CONST_NAME3, ##__VA_ARGS__); \ + } \ + else \ + { \ + constexpr ck_tile::index_t CONST_NAME1 = 1024; \ + KN_KK_SWITCH(KN_SZ, CONST_NAME2, KK_SZ, CONST_NAME3, ##__VA_ARGS__); \ + } \ + }() diff --git a/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_setting.hpp b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_setting.hpp new file mode 100755 index 000000000000..d81390bc4a6d --- /dev/null +++ b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_setting.hpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
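+//
+// Note (illustrative): the GemmTileShapeConfig specializations in this header are intended to
+// be selected through the KM_KN_KK_SWITCH / KN_KK_SWITCH / KK_SWITCH macros from
+// gemm_bool_switch.hpp, which bucket the runtime sizes (M into {128, 512, 1024}, N into
+// {128, 2880, 5120}, K into {512}) and expose them as constexpr indices. A minimal usage
+// sketch, where M, N, K stand for hypothetical runtime problem sizes:
+//
+//   KM_KN_KK_SWITCH(M, kM, N, kN, K, MaxK, [&] {
+//       using TileCfg = GemmTileShapeConfig<kM, kN, MaxK>;
+//       // kM, kN, MaxK and TileCfg::M_Tile / N_Tile / K_Tile are compile-time constants here.
+//   });
+//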
+ +#pragma once + +#include + +template +struct GemmTileShapeConfig +{ + static constexpr ck_tile::index_t M_Tile = 64; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; +}; + +template +struct GemmTileShapeConfig<128, kN, MaxK> +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; +}; + +template +struct GemmTileShapeConfig<512, 128, MaxK> +{ + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; +}; + +template +struct GemmTileShapeConfig<512, 2880, MaxK> +{ + static constexpr ck_tile::index_t M_Tile = 32; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; +}; + +template +struct GemmTileShapeConfig<512, 5120, MaxK> +{ + static constexpr ck_tile::index_t M_Tile = 64; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; +}; + +template +struct GemmTileShapeConfig<1024, 128, MaxK> +{ + static constexpr ck_tile::index_t M_Tile = 32; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; +}; + +template +struct GemmTileShapeConfig<1024, kN, MaxK> +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; 
+ static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; +}; + + diff --git a/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_utils.hpp b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_utils.hpp new file mode 100755 index 000000000000..7b0b7efa7c23 --- /dev/null +++ b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/gemm_utils.hpp @@ -0,0 +1,501 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "wp_bias_pipeline_agmem_bgmem_creg_v2.hpp" +#include "gemm_bias_pipeline_ag_bg_cr_comp_v3.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "gemm_bias_kernel.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 +#define CK_TILE_PIPELINE_COMPUTE_V5 4 +#define CK_TILE_PIPELINE_PRESHUFFLE_V1 5 +#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6 + +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if defined(CK_GFX950_SUPPORT) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; +#endif +} +template +constexpr ck_tile::index_t get_k_warp_tile_flatmm() +{ +#if defined(CK_GFX950_SUPPORT) + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 64; + else + return sizeof(PrecType) == 2 ? 32 : 128; +#else + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 32; + else + return sizeof(PrecType) == 2 ? 32 : 64; +#endif +} + +struct GemmConfigBase +{ + static constexpr bool kPadM = true; + static constexpr bool kPadN = true; + static constexpr bool kPadK = false; + + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool Preshuffle = false; +}; + +template +struct GemmConfigMemoryInterwave : public GemmConfigBase +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 
8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +template +struct GemmConfigMemoryIntrawave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; +}; + +template +struct GemmConfigComputeV3 : public GemmConfigBase +{ + // Compute V3 only support Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr int kBlockPerCu = 2; +}; + +template +struct GemmConfigComputeV4 : public GemmConfigBase +{ + // Compute V4 only support Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr 
ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV4_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV5 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 2; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; +}; + +template +struct GemmConfigPreshuffle_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V1; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = false; +}; + +template +struct GemmConfigPreshuffle_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; + + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; +}; + +template +struct GemmTypeConfig; + +template <> +struct GemmTypeConfig +{ + 
using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + // ToDo: Add more bias config to support different categories of GEMM. +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::int8_t; + using BDataType = ck_tile::int8_t; + using AccDataType = int32_t; + using CDataType = int32_t; +}; + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3_bias; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3_bias; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1; + template + using UniversalGemmPipeline = + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2_bias; + template + using UniversalGemmPipeline = + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2_bias; +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor 
data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Column by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("persistent", "0", "0:non-persistent, 1:persistent") + .insert("flush_cache", "true", "flush cache before running the kernel, defaults to true") + .insert("rotating_count", "1", "rotating count, defaults to 1"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// Type aliases for memory operation integral constants +using MemoryOpSet = + std::integral_constant; +using MemoryOpAtomicAdd = std::integral_constant; + +// host API +template +float gemm(const ck_tile::GemmHostArgs_bias& args, const ck_tile::stream_config& s); diff --git a/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/tile_gemm_shape_bias.hpp b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/tile_gemm_shape_bias.hpp new file mode 100755 index 000000000000..fa2dcb3f6b1e --- /dev/null +++ b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/tile_gemm_shape_bias.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
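+//
+// Note (illustrative): TileGemmShape_bias below bundles the block tile, the warp layout and the
+// per-warp tile into one compile-time shape descriptor. A sketch of an instantiation matching
+// the 128x128x64 block tile, 2x2x1 warp layout and 16x16x32 warp tile used elsewhere in this
+// patch; the parameter order is assumed to follow the BlockTile/BlockWarps/WarpTile aliases and
+// the PermuteA_/PermuteB_ flags declared in the struct:
+//
+//   using Shape = ck_tile::TileGemmShape_bias<
+//       ck_tile::sequence<128, 128, 64>,   // kM, kN, kK per block
+//       ck_tile::sequence<2, 2, 1>,        // warps along M, N, K
+//       ck_tile::sequence<16, 16, 32>,     // per-warp tile
+//       /*PermuteA_=*/false, /*PermuteB_=*/false>;
+//   static_assert(Shape::NumWarps == 4 && Shape::kM == 128);
+//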
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/concat.hpp" + +namespace ck_tile { + +template +struct TileGemmShape_bias +{ + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); + + static constexpr index_t kM = BlockTile::at(number<0>{}); + static constexpr index_t kN = BlockTile::at(number<1>{}); + static constexpr index_t kK = BlockTile::at(number<2>{}); + + static constexpr index_t w_kM = BlockWarps::at(number<0>{}); + static constexpr index_t w_kN = BlockWarps::at(number<1>{}); + static constexpr index_t w_kK = BlockWarps::at(number<2>{}); + + static constexpr bool PermuteA = PermuteA_; + static constexpr bool PermuteB = PermuteB_; + + static constexpr index_t flatNPerWarp = BlockWarps::at(number<1>{}); + static constexpr index_t flatKPerWarp = WarpTile::at(number<2>{}) * WarpTile::at(number<1>{}); + static constexpr index_t flatKPerBlock = flatKPerWarp * kK / WarpTile::at(number<2>{}); + + CK_TILE_HOST static std::string GetName() + { + // clang-format off + return concat('_', "tile_gemm_shape", + concat('x', kM, kN, kK, NumWarps), + concat('x', BlockWarps::at(number<0>{}), BlockWarps::at(number<1>{}), BlockWarps::at(number<2>{})), + concat('x', (WarpTile::at(number<0>{})), WarpTile::at(number<1>{}), WarpTile::at(number<2>{}))); + // clang-format on + } +}; + +} // namespace ck_tile diff --git a/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/universal_gemm_bias_kernel.hpp b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/universal_gemm_bias_kernel.hpp new file mode 100755 index 000000000000..15bdf6e42fde --- /dev/null +++ b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/universal_gemm_bias_kernel.hpp @@ -0,0 +1,1192 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/stream_utils.hpp" +#include "ck_tile/core/utility/env.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +/// @brief The Universal GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref UniversalGemmKernel "UniversalGemmKernel" when creating +/// kernel arguments object. It contain all necessary information required to build proper +/// kernel argument and launch kernel on GPU. This structure defines the GEMM problem +/// configuration by stating all required information like M,N,K sizes and respective strides. +/// NumATensor describes the number of A tensors. The minimum number of tensors is 1(required). +/// NumBTensor describes the number of B tensors. The minimum number of tensors is 1(required). +/// NumDTensor describes the number of D tensors. The minimum number of tensors is 0(not +/// required). 
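+///
+/// @par Example (illustrative)
+/// A single-A, single-B bias GEMM with no extra D tensors could be described as below.
+/// NumATensor = 1, NumBTensor = 1, NumDTensor = 0 and the pointer/stride names are assumed
+/// here for illustration only; the argument order follows the constructor defined in this
+/// struct.
+/// @code
+/// UniversalGemmHostArgs_bias</*NumATensor=*/1, /*NumBTensor=*/1, /*NumDTensor=*/0>
+///     host_args({a_ptr}, {b_ptr}, {}, e_ptr, bias_ptr,
+///               /*k_batch=*/1, M, N, K,
+///               {stride_a}, {stride_b}, {}, stride_e);
+/// @endcode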
+template +struct UniversalGemmHostArgs_bias +{ + CK_TILE_HOST UniversalGemmHostArgs_bias(const std::array& as_ptr_, + const std::array& bs_ptr_, + const std::array& ds_ptr_, + void* e_ptr_, + const void* bias_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + const std::array& stride_As_, + const std::array& stride_Bs_, + const std::array& stride_Ds_, + index_t stride_E_) + : as_ptr(as_ptr_), + bs_ptr(bs_ptr_), + ds_ptr(ds_ptr_), + e_ptr(e_ptr_), + bias_ptr(bias_ptr_), + M(M_), + N(N_), + K(K_), + stride_As(stride_As_), + stride_Bs(stride_Bs_), + stride_Ds(stride_Ds_), + stride_E(stride_E_), + k_batch(k_batch_) + { + } + + const std::array as_ptr; + const std::array bs_ptr; + const std::array ds_ptr; + union + { + void* e_ptr; + void* c_ptr; + }; + const void* bias_ptr; + index_t M; + index_t N; + index_t K; + const std::array stride_As; + const std::array stride_Bs; + const std::array stride_Ds; + union + { + index_t stride_E; + index_t stride_C; + }; + + index_t k_batch; +}; + +/// @brief The GEMM kernel device arguments. +template +struct UniversalGemmKernelArgs_bias +{ + /// @brief The As input tensor's pointer to device memory. + const std::array as_ptr; + /// @brief The Bs input tensor's pointer to device memory. + const std::array bs_ptr; + /// @brief The Ds input tensor's pointer to device memory. + const std::array ds_ptr; + /// @brief The E output tensor's pointer to device memory. + void* e_ptr; + const void* bias_ptr; + /// @brief GEMM's M dimension size. + index_t M; + /// @brief GEMM's N dimension size. + index_t N; + /// @brief GEMM's K dimension size. + index_t K; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of As tensor. + std::array stride_As; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of Bs tensor. + std::array stride_Bs; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of Ds tensor. + std::array stride_Ds; + /// @brief The distance between consecutive elements of non-contiguous dimension + /// (in memory) of E tensor. + index_t stride_E; + index_t k_batch; +}; + +/// @brief The Universal GEMM kernel template. +/// +/// @paragraph Overview Overview +/// This class provides the generic matrix multiplication kernel template. By semantic +/// division of GEMM algorithm into following parts we achieve flexible, versatile +/// and robust kernel implementation. +/// +/// @li @b Prolog - The start of GEMM kernel implementation in @ref operator() +/// function call operator" which determines the work scope of each workgroup. +/// @li @b GemmPipeline - The core part @a "heart" of matrix multiplication algorithm. +/// This is the place where each workgroup is loading data from global memory and +/// carrying out dot products. +/// @li @b Epilogue - The @a "final" part of matrix multiplication implementation +/// responsible for storing results to global memory. This is also the place where +/// any additional operator fusion may take place. +/// +/// Additionally both @ref GemmPipeline_ "GemmPipeline" and @ref EpiloguePipeline_ +/// "EpiloguePipeline" are parameterized with so called @a Policy which determines all +/// internal details of those functional parts. You can think of it like both gemm and +/// epilogue pipelines provides the control-flow logic controlled by policies. Moreover +/// the policy is responsible for definition of all necessary data layouts and thread's +/// work distribution. 
+/// +/// @tparam TilePartitioner_ The type of class providing mapping of workgroup index into the +/// output data tile to be calculated. It determines the workgroup to +/// data relationship (or in other words - which data would be +/// processed and calculated by which workgroup). +/// @tparam GemmPipeline_ The type of class which provides the core part of matrix +/// multiplication. This class should provide implementation of data +/// loading from global memory and performing block-wise matrix +/// multiplication. You can think of it as a work done by single +/// workgroup point of view. +/// @tparam EpiloguePipeline_ The type of class providing the final part of matrix +/// multiplication implementation. It is responsible for storing +/// results calculated by @ref GemmPipeline_ "GemmPipeline" to +/// the output E tensor in global memory. +template +struct UniversalGemmKernel_bias +{ + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + + static constexpr bool ADataTypeIsTuple = + is_detected::value; + static constexpr bool BDataTypeIsTuple = + is_detected::value; + static constexpr bool DDataTypeIsTuple = + is_detected::value; + static constexpr bool ALayoutIsTuple = + is_detected::value; + static constexpr bool BLayoutIsTuple = + is_detected::value; + static constexpr bool DLayoutIsTuple = + is_detected::value; + + using AsLayout = std::conditional_t, + remove_cvref_t>>; + using BsLayout = std::conditional_t, + remove_cvref_t>>; + + using DsLayout = std::conditional_t, + remove_cvref_t>>; + + using AsDataType = std::conditional_t, + remove_cvref_t>>; + + using BsDataType = std::conditional_t, + remove_cvref_t>>; + + using DsDataType = + std::conditional_t, + remove_cvref_t>>; + + using ELayout = remove_cvref_t; + using EDataType = remove_cvref_t; + + static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize; + + // Get the persistent kernel if the pipeline has it available + struct has_persistent_kernel + { + template + using has_persistent_type = decltype(T::UsePersistentKernel); + + static constexpr bool value = []() { + if constexpr(is_detected{}) + return GemmPipeline::UsePersistentKernel; + else + return false; + }(); + }; + static constexpr bool PersistentKernel = has_persistent_kernel::value; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto I3 = number<3>{}; + static constexpr auto I4 = number<4>{}; + + static constexpr index_t NumATensor = AsDataType::size(); + static constexpr index_t NumBTensor = BsDataType::size(); + static constexpr index_t NumDTensor = DsDataType::size(); + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + + static_assert(AsLayout::size() == AsDataType::size(), + "The size of AsLayout and AsDataType should be the same"); + + static_assert(BsLayout::size() == BsDataType::size(), + "The size of BsLayout and BsDataType should be the same"); + + static_assert(DsLayout::size() == DsDataType::size(), + "The size of DsLayout and DsDataType should be the same"); + + using KernelArgs = + UniversalGemmKernelArgs_bias; + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "gemm", gemm_prec_str(), GemmPipeline::GetName()); + // clang-format on + } + + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) + { + return dim3(TilePartitioner::GridSize(M, N), 1, KBatch); + } + 
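+    // Note (illustrative): the launch grid keeps the M x N output tiles on gridDim.x and the
+    // split-K batches on gridDim.z; SplitKBatchOffset below reads blockIdx.z to select the
+    // K-slice each workgroup reduces. A rough sketch, assuming the tile partitioner covers
+    // M x N with ceil-divided 128x128 tiles (tile sizes here are hypothetical):
+    //
+    //   dim3 grid  = GridSize(/*M=*/512, /*N=*/512, /*KBatch=*/2); // -> {16, 1, 2}
+    //   dim3 block = BlockSize();                                  // KernelBlockSize threads
+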
+ /** + * @brief Get the maximum occupancy grid size for the persistent kernel on the current device. + * @return The maximum occupancy grid size. + * @note This function queries the maximum occupancy of the kernel using + * `hipOccupancyMaxActiveBlocksPerMultiprocessor`. + */ + CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 + { + using Kernel = UniversalGemmKernel_bias; + const auto kernel = kentry; + int occupancy; + hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0)); + const int grid_size = get_available_compute_units(s) * occupancy; + return dim3(grid_size, 1, 1); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); } + + CK_TILE_HOST static constexpr KernelArgs + MakeKernelArgs(const UniversalGemmHostArgs_bias& hostArgs) + { + return KernelArgs{hostArgs.as_ptr, + hostArgs.bs_ptr, + hostArgs.ds_ptr, + hostArgs.e_ptr, + hostArgs.bias_ptr, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.stride_As, + hostArgs.stride_Bs, + hostArgs.stride_Ds, + hostArgs.stride_E, + hostArgs.k_batch}; + } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z) + { + constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{}); + const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1); + const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1); + + static_for<0, NumATensor, 1>{}([&](auto index) { + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + as_k_split_offset[index] = __builtin_amdgcn_readfirstlane(k_id * KRead); + } + else if constexpr(std::is_same_v) + { + as_k_split_offset[index] = + __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_As[index]); + } + }); + + static_for<0, NumBTensor, 1>{}([&](auto index) { + using BiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + bs_k_split_offset[index] = + __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_Bs[index]); + } + else if constexpr(std::is_same_v) + { + bs_k_split_offset[index] = __builtin_amdgcn_readfirstlane(k_id * KRead); + } + }); + + if(k_id < static_cast(kargs.k_batch - 1)) + { + splitted_k = __builtin_amdgcn_readfirstlane(KRead); + } + else + { + splitted_k = __builtin_amdgcn_readfirstlane(kargs.K - KRead * (kargs.k_batch - 1)); + } + } + + std::array as_k_split_offset; + std::array bs_k_split_offset; + index_t splitted_k; + }; + + CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs) + { + if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value) + { + if(kargs.k_batch != 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + } + return false; + } + } + + bool AsTesnorIsValid = {true}; + static_for<0, NumATensor, 1>{}([&](auto index) { + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } + AsTesnorIsValid = false; + } + if(kargs.K % GemmPipeline::GetVectorSizeA() != 0) + { + 
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!"); + } + AsTesnorIsValid = false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support M that is not a multiple of MPerBlock without padding!"); + } + AsTesnorIsValid = false; + } + if(kargs.M % GemmPipeline::GetVectorSizeA() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!"); + } + AsTesnorIsValid = false; + } + } + }); + + bool BsTesnorIsValid = {true}; + static_for<0, NumBTensor, 1>{}([&](auto index) { + using BiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support N that is not a multiple of NPerBlock without padding!"); + } + BsTesnorIsValid = false; + } + if(kargs.N % GemmPipeline::GetVectorSizeB() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!"); + } + BsTesnorIsValid = false; + } + } + else + { + if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 && + GemmPipeline::kPadK == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support K that is not a multiple of k_batch * KPerBlock " + "without padding!"); + } + BsTesnorIsValid = false; + } + if(kargs.K % GemmPipeline::GetVectorSizeB() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!"); + } + BsTesnorIsValid = false; + } + } + }); + + bool DTesnorIsValid = {true}; + static_for<0, NumDTensor, 1>{}([&](auto index) { + using DiLayout = remove_cvref_t>; + if(std::is_same_v == false) + { + DTesnorIsValid = false; + } + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of " + "NPerBlock without padding!"); + } + DTesnorIsValid = false; + } + if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!"); + } + DTesnorIsValid = false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of " + "MPerBlock without padding!"); + } + DTesnorIsValid = false; + } + if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!"); + } + DTesnorIsValid = false; + } + } + }); + + if constexpr(std::is_same_v) + { + if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support N that is not a multiple of NPerBlock without padding!"); + } + return false; + } + if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0) 
+ { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!"); + } + return false; + } + } + else + { + if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Can't support M that is not a multiple of MPerBlock without padding!"); + } + return false; + } + if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!"); + } + return false; + } + } + return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid; + } + + template + CK_TILE_DEVICE static auto + MakeGemmTensorViews(const std::array& as_ptr, + const std::array& bs_ptr, + const std::array& ds_ptr, + const EDataType* bias_ptr, + EDataType* e_ptr, + const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset) + { + static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); + + const auto& as_tensor_view = generate_tuple( + [&](auto i) { + using AiLayout = remove_cvref_t>; + using AiDataType = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + static_cast(as_ptr[i]), + make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_As[i], 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + static_cast(as_ptr[i]), + make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(kargs.stride_As[i], 1), + number{}, + number<1>{}); + } + }, + number{}); + + const auto& bs_tensor_view = generate_tuple( + [&](auto i) { + using BiLayout = remove_cvref_t>; + using BiDataType = remove_cvref_t>; + if constexpr(std::is_same_v) + { + if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = + std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return make_tensor_view( + static_cast(bs_ptr[i]), b_n_k_desc); + } + else + { + return make_naive_tensor_view( + bs_ptr[i], + make_tuple(splitk_batch_offset.splitted_k, kargs.N), + make_tuple(kargs.stride_Bs[i], 1), + number{}, + number<1>{}); + } + } + else + { + if constexpr(TilePartitioner::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = splitk_batch_offset.splitted_k / K1; + constexpr index_t VectorSizeB = + std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + return make_tensor_view( + static_cast(bs_ptr[i]), b_n_k_desc); + } + else + { + if 
constexpr(GemmPipeline::Preshuffle) + { + index_t kFlatK = + GemmPipeline::BlockGemmShape::flatKPerWarp * + (splitk_batch_offset.splitted_k / + TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})); + index_t kFlatN = kargs.N * kargs.K / kFlatK; + + return make_naive_tensor_view( + bs_ptr[i], + make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + bs_ptr[i], + make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.stride_Bs[i], 1), + number{}, + number<1>{}); + } + } + } + }, + number{}); + + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + using DDataType_ = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.N, kargs.M), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + }, + number{}); + + // TODO: enable vector write for C in ColMajor + const auto& e_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + e_ptr, + make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm. + make_tuple(kargs.stride_E, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + e_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_E), + number<1>{}, + number<1>{}); + } + }(); + const auto& bias_tensor_view = [&]() { + + return make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.N), + make_tuple(1), + number{}, + number<1>{}); + + }(); + return make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view,bias_tensor_view); + } + + template + CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) + { + const auto& as_pad_view = generate_tuple( + [&](auto i) { + const auto& a_tensor_view = views.at(I0); + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + const auto& b_flat_pad_view = views.at(I1); + + const auto& bs_pad_view = generate_tuple( + [&](auto i) { + const auto& b_tensor_view = views.at(I1); + using BiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(b_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(b_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + const auto& d_tensor_view = views.at(I2); + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(d_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(d_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // TODO vector write in for C in ColMajor + const auto& e_pad_view = [&]() { + const auto& e_tensor_view = views.at(I3); + if constexpr(std::is_same_v) + { + return pad_tensor_view(e_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(e_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + 
}(); + const auto& bias_pad_view = [&]() { + const auto& bias_tensor_view = views.at(I4); + return pad_tensor_view(bias_tensor_view, make_tuple(number{}), sequence{}); + }(); + if constexpr(GemmPipeline::Preshuffle) + { + // For flatmm, we need to use the flat B tensor view + return make_tuple(as_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view,bias_pad_view); + } + else + { + return make_tuple(as_pad_view, bs_pad_view, ds_pad_view, e_pad_view,bias_pad_view); + } + } + + template + CK_TILE_DEVICE static auto + MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) + { + const auto& as_pad_view = views.at(I0); + const auto& bs_pad_view = views.at(I1); + const auto& ds_pad_view = views.at(I2); + const auto& e_pad_view = views.at(I3); + const auto& bias_pad_view = views.at(I4); + const auto& as_block_window = generate_tuple( + [&](auto i) { + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(as_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(as_pad_view[i], + make_tuple(number{}, + number{}), + {0, i_m}); + } + }, + number{}); + + const auto& bs_block_window = generate_tuple( + [&](auto i) { + using BiLayout = remove_cvref_t>; + if constexpr(GemmPipeline::Preshuffle) + { + return make_tile_window( + bs_pad_view[i], + make_tuple(number{}, + number{}), + {static_cast(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), + 0}); + } + else + { + if constexpr(std::is_same_v) + { + return make_tile_window(bs_pad_view[i], + make_tuple(number{}, + number{}), + {i_n, 0}); + } + else + { + return make_tile_window(bs_pad_view[i], + make_tuple(number{}, + number{}), + {0, i_n}); + } + } + }, + number{}); + + const auto ds_block_window = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, i_n}); + } + else + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {i_n, i_m}); + } + }, + number{}); + + auto e_block_window = make_tile_window( + e_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + auto bias_block_window = make_tile_window( + bias_pad_view, + make_tuple(number{}), {i_n}); + return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window,bias_block_window); + } + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @param as_ptr input As pointer + * @param bs_ptr input Bs pointer + * @param ds_ptr input Ds pointer + * @param e_ptr output E pointer + * @param smem_ptr_0 The start memory pointer of the shared memory block. + * @param kargs GEMM kernel arguments + * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch. + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. 
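+     * @note Data flow: MakeGemmTensorViews builds the A/B/D/E/bias views, MakeGemmPadViews
+     *       pads them and MakeGemmTileWindows positions the per-block windows; the
+     *       GemmPipeline then consumes the A, B and bias windows for num_loop iterations and
+     *       the EpiloguePipeline writes the accumulated C tile through the E window (fusing
+     *       the D tensors, when present).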
+ * + */ + template + CK_TILE_DEVICE static void RunGemm(const std::array& as_ptr, + const std::array& bs_ptr, + const std::array& ds_ptr, + const EDataType* bias_ptr, + EDataType* e_ptr, + void* smem_ptr_0, + const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + as_ptr, bs_ptr, ds_ptr,bias_ptr, e_ptr, kargs, splitk_batch_offset); + + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + + // Run GEMM cooperatively by whole workgroup. + const auto& as_block_window = gemm_tile_windows.at(I0); + const auto& bs_block_window = gemm_tile_windows.at(I1); + const auto& ds_block_window = gemm_tile_windows.at(I2); + const auto& bias_block_window = gemm_tile_windows.at(I4); + + const auto& c_block_tile = + GemmPipeline{}(as_block_window[I0], bs_block_window[I0],bias_block_window, num_loop, smem_ptr_0); + + if(UseDefaultScheduler || (get_warp_id() == 0)) + { + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + } + + /** + * @brief Runs single GEMM problem cooperatively by whole workgroup. + * + * @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism. + * + * @param as_ptr input As pointer + * @param bs_ptr input Bs pointer + * @param ds_ptr input Ds pointer + * @param e_ptr output E pointer + * @param smem_ptr_0 The starting pointer of 1st shared memory block. + * @param smem_ptr_1 The starting pointer of 2nd shared memory block. + * @param kargs GEMM kernel arguments + * @param splitk_batch_offset Utility structure used to calculate k batch. + * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. + * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. + * + */ + CK_TILE_DEVICE static void RunGemm2LDS(const std::array& as_ptr, + const std::array& bs_ptr, + const std::array& ds_ptr, + const EDataType* bias_ptr, + EDataType* e_ptr, + void* __restrict__ smem_ptr_0, + void* __restrict__ smem_ptr_1, + const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + MakeGemmTensorViews( + as_ptr, bs_ptr, ds_ptr,bias_ptr, e_ptr, kargs, splitk_batch_offset); + + const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + const index_t num_loop = __builtin_amdgcn_readfirstlane( + TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + + // Run GEMM cooperatively by whole workgroup. 
+ const auto& as_block_window = gemm_tile_windows.at(I0); + const auto& bs_block_window = gemm_tile_windows.at(I1); + const auto& ds_block_window = gemm_tile_windows.at(I2); + const auto& bias_block_window = gemm_tile_windows.at(I4); + const auto& c_block_tile = GemmPipeline{}( + as_block_window[I0], bs_block_window[I0],bias_block_window, num_loop, smem_ptr_0, smem_ptr_1); + + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(I3); + + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + + // Non-persistent kernel entry point + template > + CK_TILE_DEVICE void operator()(KernelArgs kargs) const + { + const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x); + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + const SplitKBatchOffset splitk_batch_offset(kargs); + + // options + std::array as_ptr; + static_for<0, NumATensor, 1>{}([&](auto i) { + as_ptr[i] = static_cast(kargs.as_ptr[i]) + + splitk_batch_offset.as_k_split_offset[i]; + }); + + std::array bs_ptr; + static_for<0, NumBTensor, 1>{}([&](auto i) { + bs_ptr[i] = static_cast(kargs.bs_ptr[i]) + + splitk_batch_offset.bs_k_split_offset[i]; + }); + + EDataType* e_ptr = static_cast(kargs.e_ptr); + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + __shared__ char smem_ptr_1[GetSmemSize()]; + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm2LDS(as_ptr, + bs_ptr, + kargs.ds_ptr, + static_cast(kargs.bias_ptr), + e_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + else + { + if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); + RunGemm(as_ptr, + bs_ptr, + kargs.ds_ptr, + static_cast(kargs.bias_ptr), + e_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + } + + // Persistent kernel entry point + template , typename = void> + CK_TILE_DEVICE void operator()(KernelArgs kargs) const + { + const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size()); + const auto num_tiles = + __builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N)); + const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch); + auto block_id = __builtin_amdgcn_readfirstlane(get_block_id()); + + while(block_id < num_work) + { + // Get the tile index for this block + const auto tile_idx = __builtin_amdgcn_readfirstlane(block_id % num_tiles); + const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx); + const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock); + const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock); + + // Get the SplitK offset for this block + const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles); + const SplitKBatchOffset splitk_batch_offset(kargs, k_batch); + + std::array as_ptr; + static_for<0, NumATensor, 1>{}([&](auto i) { + as_ptr[i] = static_cast(kargs.as_ptr[i]) + + 
splitk_batch_offset.as_k_split_offset[i]; + }); + + std::array bs_ptr; + static_for<0, NumBTensor, 1>{}([&](auto i) { + bs_ptr[i] = static_cast(kargs.bs_ptr[i]) + + splitk_batch_offset.bs_k_split_offset[i]; + }); + + EDataType* e_ptr = static_cast(kargs.e_ptr); + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + // Run the GEMM + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + __shared__ char smem_ptr_1[GetSmemSize()]; + if constexpr(!(EpiloguePipeline::MemoryOperation == + memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm2LDS(as_ptr, + bs_ptr, + kargs.ds_ptr, + static_cast(kargs.bias_ptr), + e_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + else + { + if constexpr(!(EpiloguePipeline::MemoryOperation == + memory_operation_enum::atomic_add && + EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + is_any_of::value)) + { + RunGemm(as_ptr, + bs_ptr, + kargs.ds_ptr, + static_cast(kargs.bias_ptr), + e_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + } + // Advance to the next work item + block_id += grid_size; + if(block_id >= num_work) + { + break; + } + } + } +}; +} // namespace ck_tile diff --git a/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/wp_bias_pipeline_agmem_bgmem_creg_base_policy.hpp b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/wp_bias_pipeline_agmem_bgmem_creg_base_policy.hpp new file mode 100755 index 000000000000..c5cc74eec4c8 --- /dev/null +++ b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/wp_bias_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -0,0 +1,352 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "gemm_bias_universal_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1.hpp" +namespace ck_tile { + +struct UniversalWeightPreshufflePipelineAgBgCrPolicy_bias + : public UniversalGemmBasePolicy_bias +{ + using BasePolicy = UniversalGemmBasePolicy_bias; + + // 3d + padding + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + using namespace ck_tile; + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = GetSmemPackA(); + using ADataType = remove_cvref_t; + + constexpr auto DataTypeSize = sizeof(ADataType); + constexpr auto MLdsLayer = + (32 * 4 / kKPerBlock / DataTypeSize) < 1 ? 
1 : (32 * 4 / kKPerBlock / DataTypeSize); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple( + make_merge_transform( + make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return a_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() + { + constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) * + MakeALdsBlockDescriptor().get_element_space_size(); + return smem_size_a; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + constexpr index_t smem_size_a = GetSmemSizeA(); + + return smem_size_a; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() + { + return Problem::VectorLoadSize / sizeof(typename Problem::ADataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad() + { + using TileShape = typename Problem::BlockGemmShape; + if constexpr(TileShape::WarpTile::at(I1) == 32) + { + return TileShape::WarpTile::at(I2) / 2; + } + else + { + static_assert(TileShape::WarpTile::at(I1) == 16); + return TileShape::WarpTile::at(I2) / 4; + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; + + constexpr index_t BlockSize = Problem::kBlockSize; + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + if constexpr(std::is_same_v) + { + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = MPerBlock / M1; + constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t KPack = GetSmemPackA(); + static_assert(KPack % K3 == 0); + constexpr index_t K2 = KPack / K3; + if constexpr(get_warp_size() >= (K2 * M0)) + { + constexpr index_t K1 = get_warp_size() / (K2 * M0); + constexpr index_t K0 = BlockSize / get_warp_size(); + static_assert(KPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return 
make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + if constexpr(get_warp_size() % (M2 * K0) == 0) + { + constexpr index_t M1 = BlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t M0 = MPerBlock / (M2 * M1); + static_assert(M0 * M1 * M2 == MPerBlock, + "Incorrect M0, M2, M1 configuration! " + "M0, M1, M2 must cover whole MPerBlock!"); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + else + { + constexpr index_t M0 = BlockSize / get_warp_size(); + constexpr index_t M1 = MPerBlock / (M2 * M0); + static_assert(M0 * M1 * M2 == MPerBlock, + "Incorrect M0, M1, M2 configuration! " + "M0, M1, M2 must cover whole MPerBlock!"); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBFlatDramTileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNum = BlockSize / WaveSize; + + constexpr index_t KBPerLoad = GetKBPerLoad(); + constexpr index_t KThdPerWave = WaveSize; // threads cnt in K dim + constexpr index_t KWavePerBlk = 1; + constexpr index_t KRepeat = 1; + static_assert(TileShape::flatKPerWarp == KThdPerWave * KBPerLoad, "wrong"); + + constexpr index_t NBPerLoad = 1; + constexpr index_t NThdPerWave = 1; + constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp + constexpr index_t NRepeat = 1; + + constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, // ? 
+ tuple, // second direction + sequence>, // first direction + // wave in blk, // thd in wave + // // + tuple, sequence<1, 2>>, // which direction + tuple, sequence<2, 2>>, // which index + // + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution() + { + using ALayout = remove_cvref_t; + using ADataType = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = kMPerBlock / M1; + constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t kKPack = GetSmemPackA(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t warp_size = get_warp_size(); + if constexpr(warp_size >= (K2 * M0)) + { + constexpr index_t K1 = warp_size / (K2 * M0); + constexpr index_t K0 = kBlockSize / warp_size; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + } + + template + __host__ __device__ static constexpr auto MakeBiasDramTileDistribution() + { + + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WG = WarpGemmMfmaDispatcher;//typename BlockGemm::WarpGemm; + constexpr index_t MWarp = Problem::BlockGemmShape::w_kM; + constexpr index_t NWarp = Problem::BlockGemmShape::w_kN; + // constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + // constexpr index_t kBlockSize = Problem::kBlockSize; + // Gemm0 should be transposed + + + constexpr index_t M1 = WG::WarpGemmAttribute::Impl::kCMLane; + constexpr index_t M0 = MWarp; + // constexpr index_t M2 = kMPerBlock/(WG::WarpGemmAttribute::Impl::kM*M0); + constexpr index_t N4 = WG::WarpGemmAttribute::Impl::kCM0PerLane; + constexpr index_t N3 = WG::WarpGemmAttribute::Impl::kCNLane; + // constexpr index_t N2 = WG::WarpGemmAttribute::Impl::kCM1PerLane; + constexpr index_t N1 = NWarp; + constexpr index_t N0 = kNPerBlock / (N1 * WG::WarpGemmAttribute::Impl::kN); + // //if(get_thread_global_1d_id() ==64) + // //printf("%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,\n",MWarp,NWarp,kNPerBlock,M1,M0,N4,N3,N2,N1,N0); + // return make_static_tile_distribution( + // tile_distribution_encoding, + // tuple>, + // tuple, sequence<0, 1>>, + // tuple, sequence<2, 3>>, + // sequence<1, 1, 1>, + // sequence<2, 0, 4>>{}); + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, + tuple>, + tuple, sequence<0, 1>>, + tuple, sequence<1, 2>>, + sequence<1, 1>, + sequence<0, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffle() + { + using BlockWarps = typename 
Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmMfmaDispatcher; + + using BlockWeightPreshufflePolicy = + BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy; + return BlockWeightPreshuffleASmemBSmemCRegV1{}; + } +}; + +} // namespace ck_tile diff --git a/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/wp_bias_pipeline_agmem_bgmem_creg_v2.hpp b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/wp_bias_pipeline_agmem_bgmem_creg_v2.hpp new file mode 100755 index 000000000000..7afb0608b461 --- /dev/null +++ b/csrc/rocm/ck_extensions/ck_tile_gemm_bf16/include/wp_bias_pipeline_agmem_bgmem_creg_v2.hpp @@ -0,0 +1,1085 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/concat.hpp" +#include "wp_bias_pipeline_agmem_bgmem_creg_base_policy.hpp" + +namespace ck_tile { + +template +struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2_bias +{ + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + template + CK_TILE_HOST_DEVICE static auto + TailHandler(const RunFunction& run_func, bool, TailNumber tail_number) + { + if(tail_number == TailNumber::Odd) + { + run_func(bool_constant{}, integral_constant{}); + } + else if(tail_number == TailNumber::Even) + { + run_func(bool_constant{}, integral_constant{}); + } + } +}; + +template +struct WeightPreshufflePipelineAGmemBGmemCRegV2_bias + : public BaseWeightPreshufflePipelineAGmemBGmemCRegV2_bias +{ + using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2_bias; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; // TileFlatmmShape + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockWeightPreshuffle = + remove_cvref_t())>; + + static constexpr auto config = + BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + static constexpr index_t BlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp; + static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; + + static constexpr index_t GetVectorSizeA() + { + return PipelinePolicy::template GetVectorSizeA(); + } + static constexpr index_t GetVectorSizeB() + { + return PipelinePolicy::template GetVectorSizeB(); + } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr index_t kLdsAlignmentInBytes = 16; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + static constexpr auto I0 = number<0>(); + 
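+    // Compile-time indices; idxM/idxN/idxK below use them to pick the M/N/K entries
+    // out of the warp tile shape (see warp_m/warp_n/warp_k further down).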
static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp; + static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp; + + static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; + static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; + + static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); + static constexpr index_t ACopyLoadNum = kMPerBlock * kKPerBlock / BlockSize / K1; + static constexpr auto TailNum = Problem::TailNum; + + static constexpr auto warp_m = WarpTile::at(idxM); + static constexpr auto warp_n = WarpTile::at(idxN); + static constexpr auto warp_k = WarpTile::at(idxK); + + [[nodiscard]] CK_TILE_HOST static const std::string GetName() + { + // clang-format off + return concat('_', "pipeline_AGmemBGmemCRegV2", + concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize), + concat('x', WG::kM, WG::kN, WG::kK), + concat('x', GetVectorSizeA(), GetVectorSizeB()), + concat('x', kPadM, kPadN, kPadK)); + + // clang-format on + } + + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr index_t Preshuffle = Problem::Preshuffle; + using Base::UsePersistentKernel; + + CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return PipelinePolicy::template GetSmemSize(); + } + + CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() + { + + constexpr index_t KPerLoad = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t A_Buffer_Load_Inst_Num = kMPerBlock * kKPerBlock / BlockSize / KPerLoad; + constexpr index_t A_LDS_Read_Inst_Num = MIterPerWarp * KIterPerWarp; + constexpr index_t B_Buffer_Load_Inst_Num = NIterPerWarp * KIterPerWarp; + + // Keypoint of pipeline optimize is workload balance in time + // instruction schedule example(128X256X256, 1X4, 16X16X128): + // Iter MNK MFMA ds_read ds_write A_load b_load + // -1 M6N3: 60 2 - - - + // -1 M7N0: 61 - - - - + // -1 M7N1: 62 - - - - + // -1 M7N2: 63 - - - - + // -1 M7N3: 64 4 - - - + // 0 M0N0K0: 1 - - - - + // 0 M0N1: 2 - - - 2 + // 0 M0N2: 3 - - - - + // 0 M0N3: 4 6 - - - + // 0 M1N0: 5 - - - - + // 0 M1N1: 6 - - - 4 + // 0 M1N2: 7 - - - - + // 0 M1N3: 8 8 - - - + // 0 M2N0: 9 - - - - + // 0 M2N1: 10 - - - 6 + // 0 M2N2: 11 - - - - + // 0 M2N3: 12 10 - - - + // 0 M3N0: 13 - 1 - - + // 0 M3N1: 14 - - - 8 + // 0 M3N2: 15 - - - - + // 0 M3N3: 16 12 - - - + // 0 M4N0: 17 - 2 - - + // 0 M4N1: 18 - - - - + // 0 M4N2: 19 - - 1 - + // 0 M4N3: 20 14 - - - + // 0 M5N0: 21 - 3 - - + // 0 M5N1: 22 - - - - + // 0 M5N2: 23 - - 2 - + // 0 M5N3: 24 16 - - - + // 0 M6N0: 25 - 4 - - + // 0 M6N1: 26 - - - - + // 0 M6N2: 27 - - 3 - + // 0 M6N3: 28 17 - - - + // 0 M7N0: 29 - - - - + // 0 M7N1: 30 - - - - + // 0 M7N2: 31 - - 4 - + // 0 M7N3: 32 18 - - - + // 0 M0N0K1: 33 - - - - + // 0 M0N1: 34 - - - 10 + // 0 M0N2: 35 - - - - + // 
0 M0N3: 36 20 - - - + // 0 M1N0: 37 - - - - + // 0 M1N1: 38 - - - 12 + // 0 M1N2: 39 - - - - + // 0 M1N3: 40 22 - - - + // 0 M2N0: 41 - - - - + // 0 M2N1: 42 - - - 14 + // 0 M2N2: 43 - - - - + // 0 M2N3: 44 24 - - - + // 0 M3N0: 45 - 5 - - + // 0 M3N1: 46 - - - 16 + // 0 M3N2: 47 - - - - + // 0 M3N3: 48 26 - - - + // 0 M4N0: 49 - 6 - - + // 0 M4N1: 50 - - - - + // 0 M4N2: 51 - - 5 - + // 0 M4N3: 52 28 - - - + // 0 M5N0: 53 - 7 - - + // 0 M5N1: 54 - - - - + // 0 M5N2: 55 - - 6 - + // 0 M5N3: 56 30 - - - + // 0 M6N0: 57 - 8 - - + // 0 M6N1: 58 - - - - + // 0 M6N2: 59 - - 7 - + // 0 M6N3: 60 2 - - - + // 0 M7N0: 61 - - - - + // 0 M7N1: 62 - - - - + // 0 M7N2: 63 - - 8 - + // 0 M7N3: 64 4 - - - + + if constexpr(warp_m == 16 && warp_n == 16) + { +// MFMA -> VMEM READ -> MFMA -> DS Read -> MFMA +// hiding the glbal memory VMEM latency +#if defined(__gfx950__) + if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256) + { + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + 
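+                    // sched_group_barrier(mask, count, sync_id) asks the scheduler to place
+                    // `count` instructions of the class selected by `mask` at this point, in
+                    // the order the barriers are emitted; the masks used in this file are
+                    // 0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write,
+                    // matching the trailing comments on each call.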
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + + static_for<0, 3, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); + } +// MFMA → MFMA → MFMA → MFMA → DS Read +// For other device engine we need more aggressive MFMA with DS writes interleaved +#else + if constexpr(kMPerBlock == 128 && kNPerBlock == 256 && kKPerBlock == 256) + { + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + // Uses loops to amortize scheduling overhead + static_for<0, 4, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 
1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + }); + + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(kMPerBlock == 16 && kNPerBlock == 64 && kKPerBlock == 256) + { + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(kMPerBlock == 128 && kNPerBlock == 128 && kKPerBlock == 128) + { + // prioritize MFMA to avoid LDS write conflicts + static_for<0, 2, 1>{}([&](auto j) { + ignore = j; + static_for<0, 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); 
// VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + static_for<0, 1, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + }); + + __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, A_LDS_Read_Inst_Num - A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA + }); + static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA + }); + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA + }); + } + +#endif + } + else + { + if constexpr((A_LDS_Read_Inst_Num / 2 > + A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num)) + { + static_for<0, + A_LDS_Read_Inst_Num / 2 - A_Buffer_Load_Inst_Num - + B_Buffer_Load_Inst_Num, + 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read 
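+                // (This and the surrounding static_for loops on the non-16x16 warp-tile
+                // path size their scheduling groups from the instruction counts computed
+                // at the top of HotLoopScheduler - A_Buffer_Load_Inst_Num,
+                // A_LDS_Read_Inst_Num, B_Buffer_Load_Inst_Num - so each VMEM/DS access is
+                // paired with MFMA work that hides its latency.)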
+ __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, A_LDS_Read_Inst_Num / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 3, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x008, 4, 0); // MFMA + } + } + + template + CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + index_t num_loop, + void* p_smem_ping, + void* p_smem_pong) const + { + static_assert( + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + "wrong!"); + static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1; + const index_t iMWarp = get_warp_id() / NWarp; + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + __builtin_amdgcn_sched_barrier(0); + + // A tile in LDS + ADataType* p_a_lds_ping = static_cast(p_smem_ping); + ADataType* p_a_lds_pong = static_cast(p_smem_pong); + + constexpr auto a_lds_block_desc = + PipelinePolicy::template MakeALdsBlockDescriptor(); + + auto a_lds_block_ping = + make_tensor_view(p_a_lds_ping, a_lds_block_desc); + auto a_lds_block_pong = + make_tensor_view(p_a_lds_pong, a_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeADramTileDistribution()); + + auto a_copy_lds_window_ping = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); + + auto a_copy_lds_window_pong = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeADramTileDistribution()); + + // ping-pong window for A LDS + auto a_warp_window_ping_tmp = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + auto a_warp_window_pong_tmp = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {iMWarp * WG::kM, 0}, + make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows_ping; + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows_pong; + + static_for<0, 
MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp; + + move_tile_window(a_warp_windows_ping(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp; + + move_tile_window(a_warp_windows_pong(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Block GEMM + auto block_weight_preshuffle = BlockWeightPreshuffle(); + // Acc register tile + auto c_block_tile = block_weight_preshuffle.MakeCBlockTile(); + + // B flat DRAM window for load + auto b_flat_distribution = + PipelinePolicy::template MakeBFlatDramTileDistribution(); + auto b_flat_dram_window = // tile_window_with_static_distribution + make_tile_window( + b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views + make_tuple(number{}, number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + b_flat_distribution); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + bias_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeBiasDramTileDistribution()); + // pingpong buffer for B + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_flat_dram_windows; + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_tensor_ping; + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_tensor_pong; + + // Prefetch A0 + auto a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // prefetch B + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + // Prefill A0 + auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_ping, a_block_tile_tmp); + + __builtin_amdgcn_sched_barrier(0); + + // Prefetch A1 + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // initialize C + auto bias = load_tile(bias_dram_window); + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + sweep_tile(c_block_tile, [&](auto idx) { + // compute x = bias + x + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + c_block_tile(idx) += type_convert(bias[j_idx]) ; + + }); + + block_sync_lds(); + + // preload A00,A10 from lds + constexpr auto m_preload = (MIterPerWarp * KIterPerWarp >= 2) ? 
2 : 1; + statically_indexed_array{})(number<0>{}))), + m_preload> + a_warp_tensor_ping; + statically_indexed_array{})(number<0>{}))), + m_preload> + a_warp_tensor_pong; + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_ping(loadIter) = + load_tile(a_warp_windows_ping(number{})(number{})); + }); + __builtin_amdgcn_sched_barrier(0); + + index_t iCounter = (num_loop - 1) / 2; + while(iCounter > 0) + { + // prefetch B(2i+1) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // Prefill A(2i+1) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + + // Prefetch A(2i+2) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // GEMM 2i + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_ping(number{}), + b_warp_tensor_ping(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + __builtin_amdgcn_sched_barrier(0x7F6); + }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_ping(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + }); + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_pong(loadIter) = + load_tile(a_warp_windows_pong(number{})(number{})); + }); + HotLoopScheduler(); + + // Next K + + // prefetch B(2i+2) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // Prefill A(2i+2) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_ping, 
a_block_tile_tmp); + + // Prefetch A(2i+3) + a_block_tile = load_tile(a_copy_dram_window); + // move A window to next k + move_tile_window(a_copy_dram_window, {0, kKPerBlock}); + + // GEMM 2i+1 + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_pong(number{}), + b_warp_tensor_pong(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + __builtin_amdgcn_sched_barrier(0x7F6); + }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_pong(number{}) = + load_tile(a_warp_windows_pong(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + }); + // move B window to next flat K + move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_ping(loadIter) = + load_tile(a_warp_windows_ping(number{})(number{})); + }); + HotLoopScheduler(); + + iCounter--; + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + // __builtin_amdgcn_sched_barrier(0); + // prefetch B(loopK) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window; + + move_tile_window(b_flat_dram_windows(nIter)(kIter), + {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); + + b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + }); + }); + + // Prefill A(loopK) + a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window_pong, a_block_tile_tmp); + + // GEMM loopK-1 + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_ping(number{}), + b_warp_tensor_ping(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + __builtin_amdgcn_sched_barrier(0x7F6); + }); + // preload next A from lds + if 
constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_ping(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + }); + // TailHotLoopScheduler(); + + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MIterPerWarp; + constexpr auto kIter = loadIter / MIterPerWarp; + a_warp_tensor_pong(loadIter) = + load_tile(a_warp_windows_pong(number{})(number{})); + }); + + // __builtin_amdgcn_sched_barrier(0); + + // GEMM loopK + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_pong(number{}), + b_warp_tensor_pong(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + __builtin_amdgcn_sched_barrier(0x7F6); + }); + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_pong(number{}) = + load_tile(a_warp_windows_pong(number{})(number{})); + } + }); + }); + // TailHotLoopScheduler(); + } + else if constexpr(TailNum == TailNumber::Odd) + { + // GEMM loopK + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload; + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}(c_warp_tensor, + a_warp_tensor_ping(number{}), + b_warp_tensor_ping(nIter)(kIter)); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + + __builtin_amdgcn_sched_barrier(0x7F6); + }); + // preload next A from lds + if constexpr((kIter * MIterPerWarp + mIter) < + (KIterPerWarp * MIterPerWarp - m_preload)) + { + constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp; + constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp); + a_warp_tensor_ping(number{}) = + load_tile(a_warp_windows_ping(number{})(number{})); + } + + // barrier + if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last)) + { + block_sync_lds(); + } + }); + }); + } + + return c_block_tile; + } + + template + CK_TILE_DEVICE auto operator()(const 
ADramBlockWindowTmp& a_dram_block_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, + index_t num_loop, + void* p_smem_ping, + void* p_smem_pong) const + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType & a) { return a; }, + b_flat_dram_block_window_tmp, + bias_dram_block_window_tmp, + num_loop, + p_smem_ping, + p_smem_pong); + } +}; + +} // namespace ck_tile diff --git a/csrc/rocm/ck_extensions/include/ck/README.md b/csrc/rocm/ck_extensions/include/ck/README.md new file mode 100755 index 000000000000..92d5a5108736 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/README.md @@ -0,0 +1,23 @@ +[Back to the main page](../../README.md) +# Composable Kernel supported operations +## Supported device operations + + + + + + + + +* [GEMM](../../client_example/01_gemm/README.md) +* [Grouped Convolution Forward](../../client_example/07_grouped_convnd_fwd/README.md) +* [Grouped Convolution Backward Data](../../client_example/10_grouped_convnd_bwd_data/README.md) +* [Grouped Convolution Backward Weight](../../client_example/11_grouped_conv_bwd_weight/README.md) + + + + + + + + diff --git a/csrc/rocm/ck_extensions/include/ck/ck.hpp b/csrc/rocm/ck_extensions/include/ck/ck.hpp new file mode 100755 index 000000000000..794c6f4e2092 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/ck.hpp @@ -0,0 +1,303 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/config.h" + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS +#include "hip/hip_runtime.h" +#include "hip/hip_fp16.h" +#endif +#endif +// to do: add various levels of logging with CK_LOG_LEVEL + +#ifndef CK_TIME_KERNEL +#define CK_TIME_KERNEL 1 +#endif + +// constant address space for kernel parameter +// https://llvm.org/docs/AMDGPUUsage.html#address-spaces +#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4))) + +// launch bounds +#define CK_USE_LAUNCH_BOUNDS 1 + +#ifdef CK_USE_LAUNCH_BOUNDS +// for most kernels +#define CK_MAX_THREAD_PER_BLOCK 256 +#define CK_MIN_BLOCK_PER_CU 2 + +// for wavelet GEMM kernel +#define CK_WAVELET_MAX_THREAD_PER_BLOCK 512 +#define CK_WAVELET_MIN_BLOCK_PER_CU 2 +#endif + +// kernel attribute: amdgpu_waves_per_eu() +#ifdef CK_USE_WAVES_PER_EU +// for 1-wave kernels, control arguments of amdgpu_waves_per_eu() attribute +#ifndef CK_MIN_WAVES_PER_EU +#define CK_MIN_WAVES_PER_EU 0 +#endif + +#ifndef CK_MAX_WAVES_PER_EU +#define CK_MAX_WAVES_PER_EU 0 +#endif + +#else +#define CK_USE_WAVES_PER_EU 0 +#endif + +// define general macros for various architectures +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) +#define __gfx9__ +#endif +#if defined(__gfx942__) || defined(__gfx950__) +#define __gfx94__ +#endif +#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__) +#define __gfx101__ +#endif +#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \ + defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || \ + defined(__gfx10_3_generic__) +#define __gfx103__ +#endif +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \ + defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \ + defined(__gfx1152__) || defined(__gfx11_generic__) +#define __gfx11__ +#endif +#if defined(__gfx1200__) || defined(__gfx1201__) || 
defined(__gfx12_generic__) +#define __gfx12__ +#endif + +// buffer resource +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_BUFFER_RESOURCE_3RD_DWORD -1 +#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx9__) +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000 +#elif defined(__gfx103__) +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000 +#elif defined(__gfx11__) || defined(__gfx12__) +#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000 +#endif + +// FMA instruction +#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing +#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code +#define CK_USE_AMD_V_MAC_F32 +#elif defined(__gfx906__) || defined(__gfx9__) || defined(__gfx103__) // for GPU code +#define CK_USE_AMD_V_FMAC_F32 +#define CK_USE_AMD_V_DOT2_F32_F16 +#define CK_USE_AMD_V_DOT4_I32_I8 +#elif defined(__gfx11__) || defined(__gfx12__) +#define CK_USE_AMD_V_FMAC_F32 +#define CK_USE_AMD_V_DOT2_F32_F16 +#define CK_USE_AMD_V_DOT4_I32_I8_GFX11 +#endif + +// MFMA instruction +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_USE_AMD_MFMA +#elif defined(__gfx9__) // for GPU code +#define CK_USE_AMD_MFMA +#endif + +#if(defined(__gfx90a__) || defined(__gfx94__)) +#define CK_USE_AMD_MFMA_BF16_1K_OP +#endif + +#if defined(__gfx94__) +#define CK_USE_AMD_MFMA_GFX940 +#endif + +// buffer load +#define CK_USE_AMD_BUFFER_LOAD 1 + +// buffer store +#define CK_USE_AMD_BUFFER_STORE 1 + +// buffer atomic add: integer +#define CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER 1 + +// buffer atomic add: floating point +#ifndef __HIP_DEVICE_COMPILE__ // for host code +#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 +#elif defined(__gfx9__) || defined(__gfx12__) // for GPU code +#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1 +#else // for GPU code +#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0 +#endif + +#if(defined(__gfx90a__) || defined(__gfx94__)) // for GPU code +#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1 +#else +#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0 +#endif + +// inline asm +#define CK_USE_AMD_INLINE_ASM 1 + +// inner product (V_MAC/V_FMAC) +#define CK_USE_AMD_V_MAC_INLINE_ASM 1 + +// V_DOT inline instructions, less efficient since they require adding +// `s_nop`s to avoid hazard +#define CK_USE_AMD_V_DOT_INLINE_ASM 0 + +// inner product using V_DOT with DPP8 modifiers +#define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1 + +// LDS direct loads using inline assembly +#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 0 + +// set rounding to nearest even as default for bf16 conversions +#define CK_USE_RNE_BF16_CONVERSION 1 + +// set rounding to nearest even as default for f8 conversions +#define CK_USE_SR_F8_CONVERSION 0 + +// set rounding to nearest even as default for f6 conversions +#define CK_USE_SR_F6_CONVERSION 0 + +// set rounding to nearest even as default for f4 conversions +#define CK_USE_SR_F4_CONVERSION 0 + +// shuffle pk_i4 values during conversion to optimize number of binary +// operations +#define CK_USE_PK4_LAYOUT_SHUFFLE 1 + +// block synchronization only s_wait lgkmcnt(0), not vmcnt(0) +#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1 + +// experimental feature: multi index implemented as array +#define CK_EXPERIMENTAL_USE_DYNAMICALLY_INDEXED_MULTI_INDEX 0 + +// experimental feature: static tensor descriptor +#define CK_EXPERIMENTAL_STATIC_TENSOR_DESCRIPTOR 0 + +// experimental feature: buffer load/store/atomic-add/ OOB trick +// This (ifndef) is a hack to use customized behavior for buffer load rather than using default +// 
setting. Don't use this hack unless absolutely necessary! +// FIXME: make the behavior of buffer load a configurable (template) parameter for each usage +#ifndef CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK +#define CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 0 +#endif +#define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 +#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1 +#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK 1 + +// experimental feature: in-register sub-dword transpose +#define CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE 1 + +// experimental feature: merge transformation uses magic number division +#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1 + +// experimental feature: use __builtin_memcpy instead of a pointer cast to access a vector from a +// pointer to scalar +#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS 0 + +// experimental feature: use __builtin_memcpy instead of a union to do bit_cast +#define CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST 1 + +// experimental feature: optimize for inter-wave scheduling policy +#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING 1 +#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS 1 +// this will let make_default_loop_scheduler() return the interwave scheduling flag by default +#define CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING 0 +// experimental feature: add instances using interwave scheduling +#define CK_EXPERIMENTAL_INTER_WAVE_INSTANCES 1 +// experimental feature: add instances using pipeline v2 +#define CK_EXPERIMENTAL_PIPELINE_V2_INSTANCES 1 +// experimental feature: optimize pipeline v2 by IGLP strategy (value=ID of strategy) +#ifndef CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT +#define CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT 0 +#endif + +// hack: has underlying assumptions that need to be satisfied, otherwise it's a bug +// hack for forcing the compiler to keep idx_diff_low_const in an SGPR. idx_diff_low_const must be +// thread-invariant, otherwise it's a bug +// TODO: separate index calculation into "compile-time", "global", "block", "wave", "thread" +#define CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0 + +// workaround: conv crash when K, C is even +#define CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN 1 + +// workaround: compiler crash when compiling recursive lambda +#define CK_WORKAROUND_SWDEV_275126 1 + +// workaround: compiler crash when using buffer load/store for i8 +#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE 1 + +// workaround: compiler generating inefficient ds_write instructions +#define CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1 + +// workaround: verification failure, due to a compiler regression, for conv bwd-data fp16 using some +// tuning parameters +#define CK_WORKAROUND_SWDEV_325164 0 + +// workaround: compiler not emitting reciprocal instruction from __frcp_rn() +#define CK_WORKAROUND_SWDEV_383542 1 + +// workaround: compiler issue on gfx908 +#define CK_WORKAROUND_SWDEV_388832 1 + +// workaround: compiler issue on gfx950 +#define CK_WORKAROUND_FP16_TO_FP8_CONVERSION 1 + +// workaround: compiler issue on gfx950 +#define CK_WORKAROUND_BF16_TO_FP8_CONVERSION 1 + +// denorm test fix, necessary for gfx90a +#ifndef CK_GFX90A_DENORM_WORKAROUND +#define CK_GFX90A_DENORM_WORKAROUND 0 +#endif // CK_GFX90A_DENORM_WORKAROUND +// Enable only for gfx90a +#if defined(__gfx90a__) +#if CK_GFX90A_DENORM_WORKAROUND +#define CK_GFX90A_DENORM_WORKAROUND 1 +#endif // CK_GFX90A_DENORM_WORKAROUND is set to 1 +#else +#define CK_GFX90A_DENORM_WORKAROUND 0 +#endif // gfx90a + +// set flag to 1 to build deprecated instances +#define CK_BUILD_DEPRECATED 1 + +namespace ck { + +#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) +__device__ static constexpr int WarpSize = 64; +#else +__device__ static constexpr int WarpSize = 32; +#endif + +enum struct InMemoryDataOperationEnum +{ + Set, + AtomicAdd, + AtomicMax, + Add +}; + +// FIXME: use regular Sequence and remove this +template <InMemoryDataOperationEnum... Is> +struct InMemoryDataOperationEnumSequence +{ + static constexpr int mSize = sizeof...(Is); + + __host__ __device__ static constexpr InMemoryDataOperationEnum At(int I) + { + // the last dummy element is to prevent the compiler complaining about an empty array when mSize = 0 + const InMemoryDataOperationEnum mData[mSize + 1] = {Is..., InMemoryDataOperationEnum::Set}; + return mData[I]; + } +}; + +// index type +using index_t = int32_t; +using long_index_t = int64_t; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/config.h.in b/csrc/rocm/ck_extensions/include/ck/config.h.in new file mode 100755 index 000000000000..306a6c2ff1ae --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/config.h.in @@ -0,0 +1,144 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2025 Advanced Micro Devices, Inc.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef CK_CONFIG_H_IN +#define CK_CONFIG_H_IN + +// clang-format off +// +// DataType supports in the current CK build +// +#ifndef DTYPES +#cmakedefine DTYPES "@DTYPES@" +#endif +// if DTYPES is not defined, enable all datatypes in headerfiles +#ifndef CK_ENABLE_ALL_DTYPES +#cmakedefine CK_ENABLE_ALL_DTYPES @CK_ENABLE_ALL_DTYPES@ +#if defined(CK_ENABLE_ALL_DTYPES) +#ifndef CK_ENABLE_INT8 +#define CK_ENABLE_INT8 "ON" +#endif +#ifndef CK_ENABLE_FP8 +#define CK_ENABLE_FP8 "ON" +#endif +#ifndef CK_ENABLE_BF8 +#define CK_ENABLE_BF8 "ON" +#endif +#ifndef CK_ENABLE_FP16 +#define CK_ENABLE_FP16 "ON" +#endif +#ifndef CK_ENABLE_BF16 +#define CK_ENABLE_BF16 "ON" +#endif +#ifndef CK_ENABLE_FP32 +#define CK_ENABLE_FP32 "ON" +#endif +#ifndef CK_ENABLE_FP64 +#define CK_ENABLE_FP64 "ON" +#endif +#endif +#endif +// if DTYPES are selectively enabled +#ifndef CK_ENABLE_INT8 +#cmakedefine CK_ENABLE_INT8 @CK_ENABLE_INT8@ +#endif + +#ifndef CK_ENABLE_FP8 +#cmakedefine CK_ENABLE_FP8 @CK_ENABLE_FP8@ +#endif + +#ifndef CK_ENABLE_BF8 +#cmakedefine CK_ENABLE_BF8 @CK_ENABLE_BF8@ +#endif + +#ifndef CK_ENABLE_FP16 +#cmakedefine CK_ENABLE_FP16 @CK_ENABLE_FP16@ +#endif + +#ifndef CK_ENABLE_BF16 +#cmakedefine CK_ENABLE_BF16 @CK_ENABLE_BF16@ +#endif + +#ifndef CK_ENABLE_FP32 +#cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@ +#endif + +#ifndef CK_ENABLE_FP64 +#cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@ +#endif + +// +// Legacy DL kernel supports in the current CK build +// by default DL kernels are turned OFF +// +#ifndef CK_ENABLE_DL_KERNELS +#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@ +#endif + +#ifndef CK_ENABLE_DPP_KERNELS +#cmakedefine CK_ENABLE_DPP_KERNELS @CK_ENABLE_DPP_KERNELS@ +#endif + +// +// CK kernels which support XDL (MI series) +// +#ifndef CK_USE_XDL +#cmakedefine CK_USE_XDL @CK_USE_XDL@ +#endif + +// +// CK Kernels which support WMMA (recent Navi series) +// +#ifndef CK_USE_WMMA +#cmakedefine CK_USE_WMMA @CK_USE_WMMA@ +#endif + +#ifndef CK_USE_WMMA_FP8 +#cmakedefine CK_USE_WMMA_FP8 @CK_USE_WMMA_FP8@ +#endif + +#ifndef CK_USE_GFX94 +#cmakedefine CK_USE_GFX94 @CK_USE_GFX94@ +#endif + +#ifndef CK_USE_OCP_FP8 +#cmakedefine CK_USE_OCP_FP8 @CK_USE_OCP_FP8@ +#endif + +#ifndef CK_USE_FNUZ_FP8 +#cmakedefine CK_USE_FNUZ_FP8 @CK_USE_FNUZ_FP8@ +#endif + +#ifndef CK_USE_FP8_ON_UNSUPPORTED_ARCH +#cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH 
@CK_USE_FP8_ON_UNSUPPORTED_ARCH@ +#endif + +#ifndef CK_USE_NATIVE_MX_SUPPORT +#cmakedefine CK_USE_NATIVE_MX_SUPPORT @CK_USE_NATIVE_MX_SUPPORT@ +#endif + +// clang-format on + +#endif // CK_CONFIG_H_IN diff --git a/csrc/rocm/ck_extensions/include/ck/filesystem.hpp b/csrc/rocm/ck_extensions/include/ck/filesystem.hpp new file mode 100755 index 000000000000..41ab5f4ddbf7 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/filesystem.hpp @@ -0,0 +1,135 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef GUARD_CK_FILESYSTEM_HPP_ +#define GUARD_CK_FILESYSTEM_HPP_ + +#include <string> +#include <string_view> + +// clang-format off +#if defined(CPPCHECK) + #define CK_HAS_FILESYSTEM 1 + #define CK_HAS_FILESYSTEM_TS 1 +#elif defined(_WIN32) + #if _MSC_VER >= 1920 + #define CK_HAS_FILESYSTEM 1 + #define CK_HAS_FILESYSTEM_TS 0 + #elif _MSC_VER >= 1900 + #define CK_HAS_FILESYSTEM 0 + #define CK_HAS_FILESYSTEM_TS 1 + #else + #define CK_HAS_FILESYSTEM 0 + #define CK_HAS_FILESYSTEM_TS 0 + #endif +#elif defined(__has_include) + #if __has_include(<filesystem>) && __cplusplus >= 201703L + #define CK_HAS_FILESYSTEM 1 + #else + #define CK_HAS_FILESYSTEM 0 + #endif + #if __has_include(<experimental/filesystem>) && __cplusplus >= 201103L + #define CK_HAS_FILESYSTEM_TS 1 + #else + #define CK_HAS_FILESYSTEM_TS 0 + #endif +#else + #define CK_HAS_FILESYSTEM 0 + #define CK_HAS_FILESYSTEM_TS 0 +#endif +// clang-format on + +#if CK_HAS_FILESYSTEM +#include <filesystem> +#elif CK_HAS_FILESYSTEM_TS +#include <experimental/filesystem> +#else +#error "No filesystem include available" +#endif + +namespace CK { + +#if CK_HAS_FILESYSTEM +namespace fs = ::std::filesystem; +#elif CK_HAS_FILESYSTEM_TS +namespace fs = ::std::experimental::filesystem; +#endif + +} // namespace CK + +inline std::string operator+(const std::string_view s, const CK::fs::path& path) +{ + return path.string().insert(0, s); +} + +inline std::string operator+(const CK::fs::path& path, const std::string_view s) +{ + return path.string().append(s); +} + +#define FS_ENUM_PERMS_ALL fs::perms::all + +#if CK_HAS_FILESYSTEM_TS +#ifdef __linux__ +#include <linux/limits.h> +namespace CK { +inline fs::path weakly_canonical(const fs::path& path) +{ + std::string result(PATH_MAX, '\0'); + std::string p{path.is_relative() ? (fs::current_path() / path).string() : path.string()}; + char* retval = realpath(p.c_str(), &result[0]); + return (retval == nullptr) ? path : fs::path{result}; +} +} // namespace CK +#else +#error "Not implemented!"
+#endif +#else +namespace CK { +inline fs::path weakly_canonical(const fs::path& path) { return fs::weakly_canonical(path); } +} // namespace CK +#endif + +namespace CK { + +#ifdef _WIN32 +constexpr std::string_view executable_postfix{".exe"}; +constexpr std::string_view library_prefix{""}; +constexpr std::string_view dynamic_library_postfix{".dll"}; +constexpr std::string_view static_library_postfix{".lib"}; +constexpr std::string_view object_file_postfix{".obj"}; +#else +constexpr std::string_view executable_postfix{""}; +constexpr std::string_view library_prefix{"lib"}; +constexpr std::string_view dynamic_library_postfix{".so"}; +constexpr std::string_view static_library_postfix{".a"}; +constexpr std::string_view object_file_postfix{".o"}; +#endif + +inline fs::path make_executable_name(const fs::path& path) +{ + return path.parent_path() / (path.filename() + executable_postfix); +} + +inline fs::path make_dynamic_library_name(const fs::path& path) +{ + return path.parent_path() / (library_prefix + path.filename() + dynamic_library_postfix); +} + +inline fs::path make_object_file_name(const fs::path& path) +{ + return path.parent_path() / (path.filename() + object_file_postfix); +} + +inline fs::path make_static_library_name(const fs::path& path) +{ + return path.parent_path() / (library_prefix + path.filename() + static_library_postfix); +} + +struct FsPathHash +{ + std::size_t operator()(const fs::path& path) const { return fs::hash_value(path); } +}; +} // namespace CK + +#endif // GUARD_CK_FILESYSTEM_HPP_ diff --git a/csrc/rocm/ck_extensions/include/ck/host_utility/device_prop.hpp b/csrc/rocm/ck_extensions/include/ck/host_utility/device_prop.hpp new file mode 100755 index 000000000000..5439bbe1f077 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/host_utility/device_prop.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#ifndef __HIPCC_RTC__ +#include +#include +#include + +namespace ck { + +constexpr unsigned int fnv1a_hash(std::string_view str, unsigned int h = 2166136261u) +{ + return str.empty() ? h + : fnv1a_hash(str.substr(1), + (h ^ static_cast(str.front())) * 16777619u); +} +inline std::string get_device_name() +{ + hipDeviceProp_t props{}; + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return std::string(); + } + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return std::string(); + } + const std::string raw_name(props.gcnArchName); + const auto name = raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str. 
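+    // Map legacy marketing names and codenames reported by gcnArchName to canonical gfxXXX strings; fnv1a_hash is constexpr, so the name strings can be used directly as case labels.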
+ switch(fnv1a_hash(name)) + { + // https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40 + case fnv1a_hash("Ellesmere"): + case fnv1a_hash("Baffin"): + case fnv1a_hash("RacerX"): + case fnv1a_hash("Polaris10"): + case fnv1a_hash("Polaris11"): + case fnv1a_hash("Tonga"): + case fnv1a_hash("Fiji"): + case fnv1a_hash("gfx800"): + case fnv1a_hash("gfx802"): + case fnv1a_hash("gfx804"): return "gfx803"; + case fnv1a_hash("Vega10"): + case fnv1a_hash("gfx901"): return "gfx900"; + case fnv1a_hash("10.3.0 Sienna_Cichlid 18"): return "gfx1030"; + default: return name; + } +} + +inline bool is_xdl_supported() +{ + return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || + ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"; +} + +inline bool is_lds_direct_load_supported() +{ + // Check if direct loads from global memory to LDS are supported. + return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx942" || + ck::get_device_name() == "gfx950"; +} + +inline bool is_bf16_atomic_supported() +{ + return ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"; +} + +inline bool is_gfx101_supported() +{ + return ck::get_device_name() == "gfx1010" || ck::get_device_name() == "gfx1011" || + ck::get_device_name() == "gfx1012"; +} + +inline bool is_gfx103_supported() +{ + return ck::get_device_name() == "gfx1030" || ck::get_device_name() == "gfx1031" || + ck::get_device_name() == "gfx1032" || ck::get_device_name() == "gfx1034" || + ck::get_device_name() == "gfx1035" || ck::get_device_name() == "gfx1036"; +} + +inline bool is_gfx11_supported() +{ + return ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || + ck::get_device_name() == "gfx1102" || ck::get_device_name() == "gfx1103" || + ck::get_device_name() == "gfx1150" || ck::get_device_name() == "gfx1151" || + ck::get_device_name() == "gfx1152"; +} + +inline bool is_gfx12_supported() +{ + return ck::get_device_name() == "gfx1200" || ck::get_device_name() == "gfx1201"; +} + +} // namespace ck +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/host_utility/flush_cache.hpp b/csrc/rocm/ck_extensions/include/ck/host_utility/flush_cache.hpp new file mode 100755 index 000000000000..08b3aba2b3a3 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/host_utility/flush_cache.hpp @@ -0,0 +1,378 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
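+// Benchmarking helpers: RotatingMemWrapper and RotatingMemWrapperMultiD rotate through extra device-side copies of the A/B (and D) tensors between timed launches so repeated runs do not reuse warm data caches, and flush_icache() launches the ck::flush_icache kernel across all compute units before timing.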
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/env.hpp" +#include "ck/stream_config.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/flush_icache.hpp" +namespace ck { +namespace utility { + +template +struct RotatingMemWrapperMultiD +{ + static constexpr index_t NumDs = DsDataType::Size(); + + using ADataType = decltype(Argument::p_a_grid); + using BDataType = decltype(Argument::p_b_grid); + using DsGridPointer = decltype(Argument::p_ds_grid); + + RotatingMemWrapperMultiD() = delete; + RotatingMemWrapperMultiD(Argument& arg_, + std::size_t rotating_count_, + std::size_t size_a_, + std::size_t size_b_, + std::array size_ds_) + : arg(arg_), + rotating_count(rotating_count_), + size_a(size_a_), + size_b(size_b_), + size_ds(size_ds_) + { + p_a_grids.push_back(arg.p_a_grid); + p_b_grids.push_back(arg.p_b_grid); + p_ds_grids.push_back(arg.p_ds_grid); + for(size_t i = 1; i < rotating_count; i++) + { + { + void* pADeviceBuf; + hip_check_error(hipMalloc(static_cast(&pADeviceBuf), size_a_)); + hip_check_error(hipMemcpy(static_cast(pADeviceBuf), + const_cast(p_a_grids[0]), + size_a_, + hipMemcpyDeviceToDevice)); + p_a_grids.push_back(pADeviceBuf); + } + + { + void* pBDeviceBuf; + hip_check_error(hipMalloc(static_cast(&pBDeviceBuf), size_b_)); + hip_check_error(hipMemcpy(static_cast(pBDeviceBuf), + const_cast(p_b_grids[0]), + size_b_, + hipMemcpyDeviceToDevice)); + p_b_grids.push_back(pBDeviceBuf); + } + + { + + DsGridPointer ds_buffer; + static_for<0, NumDs, 1>{}([&](auto j) { + void* pDDeviceBuf; + hip_check_error(hipMalloc(static_cast(&pDDeviceBuf), size_ds_[j])); + hip_check_error(hipMemcpy(static_cast(pDDeviceBuf), + static_cast(p_ds_grids[0][j]), + size_ds_[j], + hipMemcpyDeviceToDevice)); + + using DDataType = remove_cvref_t>; + + ds_buffer(j) = static_cast(pDDeviceBuf); + }); + + p_ds_grids.push_back(ds_buffer); + } + } + } + + void Next() + { + if(rotating_count > 1) + { + std::size_t idx = iter++ % rotating_count; + arg.p_a_grid = reinterpret_cast(p_a_grids[idx]); + arg.p_b_grid = reinterpret_cast(p_b_grids[idx]); + arg.p_ds_grid = p_ds_grids[idx]; + } + } + void Print() + { + std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b + << ", rotating_count: " << rotating_count << "}" << std::endl; + } + ~RotatingMemWrapperMultiD() + { + if(rotating_count > 1) + { + // restore ptr + arg.p_a_grid = reinterpret_cast(p_a_grids[0]); + arg.p_b_grid = reinterpret_cast(p_b_grids[0]); + arg.p_ds_grid = p_ds_grids[0]; + + // free device mem + for(size_t i = 1; i < rotating_count; i++) + { + hip_check_error(hipFree(const_cast(p_a_grids[i]))); + hip_check_error(hipFree(const_cast(p_b_grids[i]))); + + static_for<0, NumDs, 1>{}([&](auto j) { + using DDataType = remove_cvref_t>; + hip_check_error( + hipFree(static_cast(const_cast(p_ds_grids[i][j])))); + }); + } + } + } + + private: + Argument& arg; + std::size_t iter = 0; + std::size_t rotating_count = 1; + std::size_t size_a = 0; + std::size_t size_b = 0; + std::array size_ds = {0}; + std::vector p_a_grids; + std::vector p_b_grids; + std::vector p_ds_grids; +}; + +template +struct RotatingMemWrapper +{ + using ADataType = decltype(Argument::p_a_grid); + using BDataType = decltype(Argument::p_b_grid); + + RotatingMemWrapper() = delete; + RotatingMemWrapper(Argument& arg_, + std::size_t rotating_count_, + std::size_t size_a_, + std::size_t size_b_) + : arg(arg_), rotating_count(rotating_count_), size_a(size_a_), size_b(size_b_) + { + 
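+        // Slot 0 keeps the caller's original A/B pointers; each additional slot allocated below holds a device-side copy of the original data so that Next() can rotate between them.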
p_a_grids.push_back(arg.p_a_grid); + p_b_grids.push_back(arg.p_b_grid); + for(size_t i = 1; i < rotating_count; i++) + { + { + void* pADeviceBuf; + hip_check_error(hipMalloc(static_cast(&pADeviceBuf), size_a_)); + hip_check_error(hipMemcpy(static_cast(pADeviceBuf), + const_cast(p_a_grids[0]), + size_a_, + hipMemcpyDeviceToDevice)); + p_a_grids.push_back(pADeviceBuf); + } + + { + void* pBDeviceBuf; + hip_check_error(hipMalloc(static_cast(&pBDeviceBuf), size_b_)); + hip_check_error(hipMemcpy(static_cast(pBDeviceBuf), + const_cast(p_b_grids[0]), + size_b_, + hipMemcpyDeviceToDevice)); + p_b_grids.push_back(pBDeviceBuf); + } + } + } + + void Next() + { + if(rotating_count > 1) + { + std::size_t idx = iter++ % rotating_count; + arg.p_a_grid = reinterpret_cast(p_a_grids[idx]); + arg.p_b_grid = reinterpret_cast(p_b_grids[idx]); + } + } + void Print() + { + std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b + << ", rotating_count: " << rotating_count << "}" << std::endl; + } + ~RotatingMemWrapper() + { + if(rotating_count > 1) + { + // restore ptr + arg.p_a_grid = reinterpret_cast(p_a_grids[0]); + arg.p_b_grid = reinterpret_cast(p_b_grids[0]); + + // free device mem + for(size_t i = 1; i < rotating_count; i++) + { + hip_check_error(hipFree(const_cast(p_a_grids[i]))); + hip_check_error(hipFree(const_cast(p_b_grids[i]))); + } + } + } + + private: + Argument& arg; + std::size_t iter = 0; + std::size_t rotating_count = 1; + std::size_t size_a = 0; + std::size_t size_b = 0; + std::vector p_a_grids; + std::vector p_b_grids; +}; + +inline void flush_icache() +{ + hipDeviceProp_t deviceProps; + hip_check_error(hipGetDeviceProperties(&deviceProps, 0)); + int32_t gpu_block3 = deviceProps.multiProcessorCount * 60; + + ck::flush_icache<<>>(); + hip_check_error(hipGetLastError()); +} +// if TimePrePress == false, return time does not include preprocess's time +template +float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, + PreProcessFunc preprocess, + F kernel, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + GemmArgs& gemm_args, + Args... 
args) +{ +#if CK_TIME_KERNEL +#define MEDIAN 0 + if(stream_config.time_kernel_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up %d times\n", stream_config.cold_niters_); + } + // warm up + for(int i = 0; i < stream_config.cold_niters_; ++i) + { + kernel<<>>(gemm_args, args...); + hip_check_error(hipGetLastError()); + } + + const int nrepeat = stream_config.nrepeat_; + if(nrepeat == 0) + { + return 0.0; + } + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("Start running %d times...\n", nrepeat); + } + +#if MEDIAN + std::set times; +#else + float total_time = 0; +#endif + hipEvent_t start, stop; + + hip_check_error(hipEventCreate(&start)); + hip_check_error(hipEventCreate(&stop)); + + hip_check_error(hipDeviceSynchronize()); + hip_check_error(hipEventRecord(start, stream_config.stream_id_)); + + for(int i = 0; i < nrepeat; ++i) + { + if constexpr(!TimePreprocess) + { + preprocess(); + } + + // hipEvent_t start, stop; + + // hip_check_error(hipEventCreate(&start)); + // hip_check_error(hipEventCreate(&stop)); + + // hip_check_error(hipDeviceSynchronize()); + // hip_check_error(hipEventRecord(start, stream_config.stream_id_)); + // calculate preprocess time + if constexpr(TimePreprocess) + { + preprocess(); + } + // run real kernel + kernel<<>>(gemm_args, args...); + hip_check_error(hipGetLastError()); + // end real kernel + + // hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); + // hip_check_error(hipEventSynchronize(stop)); + // float cur_time = 0; + // hip_check_error(hipEventElapsedTime(&cur_time, start, stop)); + // #if MEDIAN + // times.insert(cur_time); + // #else + // total_time += cur_time; + // #endif + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + // std::cout << "i: " << i << " cur_time: " << cur_time << std::endl; + + printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n", + static_cast(gemm_args.p_a_grid), + static_cast(gemm_args.p_b_grid)); + } + } + hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); + hip_check_error(hipEventSynchronize(stop)); + float cur_time = 0; + hip_check_error(hipEventElapsedTime(&cur_time, start, stop)); +#if MEDIAN + times.insert(cur_time); +#else + total_time += cur_time; +#endif + +#if MEDIAN + auto mid = times.begin(); + std::advance(mid, (nrepeat - 1) / 2); + if(nrepeat % 2 == 1) + { + return *mid; + } + else + { + auto mid_next = mid; + std::advance(mid_next, 1); + return (*mid + *mid_next) / 2; + } +#else + // return total_time / nrepeat; + hipDeviceProp_t deviceProps; + hip_check_error(hipGetDeviceProperties(&deviceProps, 0)); + float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 
0.005 : 0.01; + return (total_time - preprocess_offset * nrepeat) / nrepeat; +#endif + } + else + { + preprocess(); + kernel<<>>(gemm_args, args...); + hip_check_error(hipGetLastError()); + + return 0; + } +#else + kernel<<>>(gemm_args, args...); + hip_check_error(hipGetLastError()); + + return 0; +#endif +} + +} // namespace utility +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/host_utility/hip_check_error.hpp b/csrc/rocm/ck_extensions/include/ck/host_utility/hip_check_error.hpp new file mode 100755 index 000000000000..0dfd2752695c --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/host_utility/hip_check_error.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +// To be removed, which really does not tell the location of failed HIP functional call +inline void hip_check_error(hipError_t x) +{ + if(x != hipSuccess) + { + std::ostringstream ss; + ss << "HIP runtime error: " << hipGetErrorString(x) << ". " + << "hip_check_error.hpp" + << ": " << __LINE__ << "in function: " << __func__; + throw std::runtime_error(ss.str()); + } +} + +#define HIP_CHECK_ERROR(retval_or_funcall) \ + do \ + { \ + hipError_t _tmpVal = retval_or_funcall; \ + if(_tmpVal != hipSuccess) \ + { \ + std::ostringstream ostr; \ + ostr << "HIP Function Failed (" << __FILE__ << "," << __LINE__ << ") " \ + << hipGetErrorString(_tmpVal); \ + throw std::runtime_error(ostr.str()); \ + } \ + } while(0) diff --git a/csrc/rocm/ck_extensions/include/ck/host_utility/io.hpp b/csrc/rocm/ck_extensions/include/ck/host_utility/io.hpp new file mode 100755 index 000000000000..55734bab2e46 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/host_utility/io.hpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" + +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + std::copy(std::begin(v), std::end(v), std::ostream_iterator(os, " ")); + return os; +} + +template +std::ostream& operator<<(std::ostream& os, const std::array& v) +{ + std::copy(std::begin(v), std::end(v), std::ostream_iterator(os, " ")); + return os; +} + +template +std::ostream& operator<<(std::ostream& os, const ck::TensorDescriptor& desc) +{ + constexpr ck::index_t nDim = ck::remove_cvref_t::GetNumOfDimension(); + + os << "{"; + + ck::static_for<0, nDim - 1, 1>{}([&](auto i) { os << desc.GetLength(i) << ", "; }); + + os << desc.GetLength(ck::Number{}); + + os << "}"; + + return os; +} diff --git a/csrc/rocm/ck_extensions/include/ck/host_utility/kernel_launch.hpp b/csrc/rocm/ck_extensions/include/ck/host_utility/kernel_launch.hpp new file mode 100755 index 000000000000..11a1c9bbc0d6 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/host_utility/kernel_launch.hpp @@ -0,0 +1,170 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#ifndef __HIPCC_RTC__ +#include + +#include "ck/ck.hpp" +#include "ck/utility/env.hpp" +#include "ck/stream_config.hpp" +#include "ck/host_utility/hip_check_error.hpp" + +template +float launch_and_time_kernel(const StreamConfig& stream_config, + F kernel, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + Args... 
args) +{ +#if CK_TIME_KERNEL + if(stream_config.time_kernel_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up %d times\n", stream_config.cold_niters_); + } + // warm up + for(int i = 0; i < stream_config.cold_niters_; ++i) + { + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + } + + const int nrepeat = stream_config.nrepeat_; + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("Start running %d times...\n", nrepeat); + } + hipEvent_t start, stop; + + hip_check_error(hipEventCreate(&start)); + hip_check_error(hipEventCreate(&stop)); + + hip_check_error(hipDeviceSynchronize()); + hip_check_error(hipEventRecord(start, stream_config.stream_id_)); + + for(int i = 0; i < nrepeat; ++i) + { + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + } + + hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); + hip_check_error(hipEventSynchronize(stop)); + + float total_time = 0; + + hip_check_error(hipEventElapsedTime(&total_time, start, stop)); + + hip_check_error(hipEventDestroy(start)); + hip_check_error(hipEventDestroy(stop)); + + return total_time / nrepeat; + } + else + { + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + + return 0; + } +#else + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + + return 0; +#endif +} + +template +float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, + PreProcessFunc preprocess, + F kernel, + dim3 grid_dim, + dim3 block_dim, + std::size_t lds_byte, + Args... args) +{ +#if CK_TIME_KERNEL + if(stream_config.time_kernel_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n", + __func__, + grid_dim.x, + grid_dim.y, + grid_dim.z, + block_dim.x, + block_dim.y, + block_dim.z); + + printf("Warm up %d times\n", stream_config.cold_niters_); + } + // warm up + preprocess(); + for(int i = 0; i < stream_config.cold_niters_; ++i) + { + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + } + + const int nrepeat = stream_config.nrepeat_; + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("Start running %d times...\n", nrepeat); + } + hipEvent_t start, stop; + + hip_check_error(hipEventCreate(&start)); + hip_check_error(hipEventCreate(&stop)); + + hip_check_error(hipDeviceSynchronize()); + hip_check_error(hipEventRecord(start, stream_config.stream_id_)); + + for(int i = 0; i < nrepeat; ++i) + { + preprocess(); + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + } + + hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); + hip_check_error(hipEventSynchronize(stop)); + + float total_time = 0; + + hip_check_error(hipEventElapsedTime(&total_time, start, stop)); + + hip_check_error(hipEventDestroy(start)); + hip_check_error(hipEventDestroy(stop)); + + return total_time / nrepeat; + } + else + { + preprocess(); + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + + return 0; + } +#else + kernel<<>>(args...); + hip_check_error(hipGetLastError()); + + return 0; +#endif +} +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/host_utility/stream_utility.hpp b/csrc/rocm/ck_extensions/include/ck/host_utility/stream_utility.hpp new file mode 100755 index 000000000000..9ab49489bbe5 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/host_utility/stream_utility.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT 
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/stream_config.hpp" +#include "ck/host_utility/hip_check_error.hpp" + +static inline int getAvailableComputeUnitCount(const StreamConfig& stream_config) +{ + constexpr int MAX_MASK_DWORDS = 64; + + // assume at most 64*32 = 2048 CUs + uint32_t cuMask[MAX_MASK_DWORDS]; + + for(int i = 0; i < MAX_MASK_DWORDS; i++) + cuMask[i] = 0; + + auto countSetBits = [](uint32_t dword) { + int count = 0; + + while(dword != 0) + { + if(dword & 0x1) + count++; + + dword = dword >> 1; + }; + + return (count); + }; + + hip_check_error(hipExtStreamGetCUMask(stream_config.stream_id_, MAX_MASK_DWORDS, &cuMask[0])); + + int ret = 0; + + for(int i = 0; i < MAX_MASK_DWORDS; i++) + ret += countSetBits(cuMask[i]); + + return (ret); +}; diff --git a/csrc/rocm/ck_extensions/include/ck/library/utility/algorithm.hpp b/csrc/rocm/ck_extensions/include/ck/library/utility/algorithm.hpp new file mode 100755 index 000000000000..57136f8a2a19 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/library/utility/algorithm.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +namespace ck { +namespace ranges { +template +auto copy(InputRange&& range, OutputIterator iter) + -> decltype(std::copy(std::begin(std::forward(range)), + std::end(std::forward(range)), + iter)) +{ + return std::copy(std::begin(std::forward(range)), + std::end(std::forward(range)), + iter); +} + +template +auto fill(OutputRange&& range, const T& init) + -> std::void_t(range)), + std::end(std::forward(range)), + init))> +{ + std::fill(std::begin(std::forward(range)), + std::end(std::forward(range)), + init); +} + +template +auto transform(InputRange&& range, OutputIterator iter, UnaryOperation unary_op) + -> decltype(std::transform(std::begin(range), std::end(range), iter, unary_op)) +{ + return std::transform(std::begin(range), std::end(range), iter, unary_op); +} + +} // namespace ranges +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/library/utility/check_err.hpp b/csrc/rocm/ck_extensions/include/ck/library/utility/check_err.hpp new file mode 100755 index 000000000000..d33ecaeef8b8 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/library/utility/check_err.hpp @@ -0,0 +1,505 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
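+// Host-side verification helpers: get_relative_threshold and get_absolute_threshold derive error bounds from the mantissa widths of the compute/output/accumulation types, and the check_err overloads compare device output against a host reference using per-type rtol/atol defaults, reporting the first few mismatches and the overall error rate.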
+ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/type.hpp" +#include "ck/host_utility/io.hpp" + +#include "ck/library/utility/ranges.hpp" + +namespace ck { +namespace utils { + +template +double get_relative_threshold(const int number_of_accumulations = 1) +{ + using F4 = ck::f4_t; + using F8 = ck::f8_t; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using F32 = float; + using I8 = int8_t; + using I32 = int32_t; + + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, + "Warning: Unhandled ComputeDataType for setting up the relative threshold!"); + double compute_error = 0; + if constexpr(is_same_v || is_same_v || + is_same_v) + { + return 0; + } + else + { + compute_error = std::pow(2, -NumericUtils::mant) * 0.5; + } + + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, + "Warning: Unhandled OutDataType for setting up the relative threshold!"); + double output_error = 0; + if constexpr(is_same_v || is_same_v || + is_same_v) + { + return 0; + } + else + { + output_error = std::pow(2, -NumericUtils::mant) * 0.5; + } + double midway_error = std::max(compute_error, output_error); + + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, + "Warning: Unhandled AccDataType for setting up the relative threshold!"); + double acc_error = 0; + if constexpr(is_same_v || is_same_v || + is_same_v) + { + return 0; + } + else + { + acc_error = std::pow(2, -NumericUtils::mant) * 0.5 * number_of_accumulations; + } + return std::max(acc_error, midway_error); +} + +template +double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1) +{ + using F4 = ck::f4_t; + using F8 = ck::f8_t; + using F16 = ck::half_t; + using BF16 = ck::bhalf_t; + using F32 = float; + using I8 = int8_t; + using I32 = int32_t; + + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, + "Warning: Unhandled ComputeDataType for setting up the absolute threshold!"); + auto expo = std::log2(std::abs(max_possible_num)); + double compute_error = 0; + if constexpr(is_same_v || is_same_v || + is_same_v) + { + return 0; + } + else + { + compute_error = std::pow(2, expo - NumericUtils::mant) * 0.5; + } + + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, + "Warning: Unhandled OutDataType for setting up the absolute threshold!"); + double output_error = 0; + if constexpr(is_same_v || is_same_v || + is_same_v) + { + return 0; + } + else + { + output_error = std::pow(2, expo - NumericUtils::mant) * 0.5; + } + double midway_error = std::max(compute_error, output_error); + + static_assert(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v, + "Warning: Unhandled AccDataType for setting up the absolute threshold!"); + double acc_error = 0; + if constexpr(is_same_v || is_same_v || + is_same_v) + { + return 0; + } + else + { + acc_error = + std::pow(2, expo - NumericUtils::mant) * 0.5 * number_of_accumulations; + } + return std::max(acc_error, midway_error); +} + +template +typename std::enable_if< + std::is_same_v, ranges::range_value_t> && + std::is_floating_point_v> 
&& + !std::is_same_v, half_t>, + bool>::type +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-5, + double atol = 3e-6) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = *std::next(std::begin(out), i); + const double r = *std::next(std::begin(ref), i); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + if(!res) + { + const float error_percent = + static_cast(err_count) / static_cast(out.size()) * 100.f; + std::cerr << "max err: " << max_err; + std::cerr << ", number of errors: " << err_count; + std::cerr << ", " << error_percent << "% wrong values" << std::endl; + } + return res; +} + +template +typename std::enable_if< + std::is_same_v, ranges::range_value_t> && + std::is_same_v, bhalf_t>, + bool>::type +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-1, + double atol = 1e-3) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + // TODO: This is a hack. We should have proper specialization for bhalf_t data type. + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? 
err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + if(!res) + { + const float error_percent = + static_cast(err_count) / static_cast(out.size()) * 100.f; + std::cerr << "max err: " << max_err; + std::cerr << ", number of errors: " << err_count; + std::cerr << ", " << error_percent << "% wrong values" << std::endl; + } + return res; +} + +template +typename std::enable_if< + std::is_same_v, ranges::range_value_t> && + std::is_same_v, half_t>, + bool>::type +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = NumericLimits>::Min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + if(!res) + { + const float error_percent = + static_cast(err_count) / static_cast(out.size()) * 100.f; + std::cerr << "max err: " << max_err; + std::cerr << ", number of errors: " << err_count; + std::cerr << ", " << error_percent << "% wrong values" << std::endl; + } + return res; +} + +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_integral_v> && + !std::is_same_v, bhalf_t> && + !std::is_same_v, f8_t> && + !std::is_same_v, bf8_t>) +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + || std::is_same_v, int4_t> +#endif + , + bool> +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double = 0, + double atol = 0) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + int64_t err = 0; + int64_t max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const int64_t o = *std::next(std::begin(out), i); + const int64_t r = *std::next(std::begin(ref), i); + err = std::abs(o - r); + + if(err > atol) + { + max_err = err > max_err ? 
err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r + << std::endl; + } + res = false; + } + } + if(!res) + { + const float error_percent = + static_cast(err_count) / static_cast(out.size()) * 100.f; + std::cerr << "max err: " << max_err; + std::cerr << ", number of errors: " << err_count; + std::cerr << ", " << error_percent << "% wrong values" << std::endl; + } + return res; +} + +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_same_v, f8_t>), + bool> +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + + if(!res) + { + std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err + << " number of errors: " << err_count << std::endl; + } + return res; +} + +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_same_v, bf8_t>), + bool> +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 1e-3, + double atol = 1e-3) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? 
err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + if(!res) + { + std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl; + } + return res; +} + +template +std::enable_if_t<(std::is_same_v, ranges::range_value_t> && + std::is_same_v, f4_t>), + bool> +check_err(const Range& out, + const RefRange& ref, + const std::string& msg = "Error: Incorrect results!", + double rtol = 0.5, + double atol = 0.5) +{ + if(out.size() != ref.size()) + { + std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size() + << std::endl; + return false; + } + + bool res{true}; + int err_count = 0; + double err = 0; + double max_err = std::numeric_limits::min(); + + for(std::size_t i = 0; i < ref.size(); ++i) + { + const double o = type_convert(*std::next(std::begin(out), i)); + const double r = type_convert(*std::next(std::begin(ref), i)); + err = std::abs(o - r); + + if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) + { + max_err = err > max_err ? err : max_err; + err_count++; + if(err_count < 5) + { + std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i + << "] != ref[" << i << "]: " << o << " != " << r << std::endl; + } + res = false; + } + } + + if(!res) + { + std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err + << " number of errors: " << err_count << std::endl; + } + return res; +} + +} // namespace utils +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/library/utility/conv_common.hpp b/csrc/rocm/ck_extensions/include/ck/library/utility/conv_common.hpp new file mode 100755 index 000000000000..085454f42dee --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/library/utility/conv_common.hpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
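+// Host-side convolution helpers: build the default 4D output tensor descriptor from the input/weight descriptors, strides, dilations and pads, and count forward-convolution FLOPs as 2 * N * K * Ho * Wo * C * Y * X.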
+ +#pragma once + +#include "ck/tensor_description/tensor_descriptor.hpp" + +template +constexpr auto get_convolution_output_default_4d_tensor_descriptor( + const ck::TensorDescriptor& in_desc, + const ck::TensorDescriptor& wei_desc, + const ConvStrides& conv_strides, + const ConvDilations conv_dilations, + const LeftPads& left_pads, + const RightPads& right_pads) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + assert(in_desc.GetNumOfDimension() == 4); + assert(wei_desc.GetNumOfDimension() == 4); + assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1)); + + const auto N = in_desc.GetLength(I0); + const auto Hi = in_desc.GetLength(I2); + const auto Wi = in_desc.GetLength(I3); + + const auto K = wei_desc.GetLength(I0); + const auto Y = wei_desc.GetLength(I2); + const auto X = wei_desc.GetLength(I3); + + const auto LeftPadH = left_pads[I0]; + const auto LeftPadW = left_pads[I1]; + + const auto RightPadH = right_pads[I0]; + const auto RightPadW = right_pads[I1]; + + const auto YEff = (Y - I1) * conv_dilations[I0] + I1; + const auto XEff = (X - I1) * conv_dilations[I1] + I1; + + const auto Ho = (Hi + LeftPadH + RightPadH - YEff) / conv_strides[I0] + I1; + const auto Wo = (Wi + LeftPadW + RightPadW - XEff) / conv_strides[I1] + I1; + + return make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho, Wo)); +} + +template +constexpr std::size_t +calculate_convolution_flops(const InDesc&, const WeiDesc& wei_desc, const OutDesc& out_desc) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + const index_t N = out_desc.GetLength(I0); + const index_t K = out_desc.GetLength(I1); + const index_t Ho = out_desc.GetLength(I2); + const index_t Wo = out_desc.GetLength(I3); + + const index_t C = wei_desc.GetLength(I1); + const index_t Y = wei_desc.GetLength(I2); + const index_t X = wei_desc.GetLength(I3); + + return std::size_t(2) * N * K * Ho * Wo * C * Y * X; +} diff --git a/csrc/rocm/ck_extensions/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp b/csrc/rocm/ck_extensions/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp new file mode 100755 index 000000000000..d4ceefb4589b --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/library/utility/convolution_host_tensor_descriptor_helper.hpp @@ -0,0 +1,395 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/utility/convolution_parameter.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace utils { +namespace conv { + +namespace detail { + +template +std::vector get_layout_transpose_gnchw_to_old() +{ + // HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function, + // is used by some legacy kernel. 
New kernel should use GNHWK/GKYXC/GNHWK + // TODO: remove this branch after removing legacy kernel + if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {0, 1, 3, 2}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {0, 1, 4, 2, 3}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {0, 1, 5, 2, 3, 4}; + } + // separate from legacy code above + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {0, 1, 2, 3}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v) + { + return {1, 0, 2, 3}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v) + { + return {1, 0, 2, 3, 4}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v) + { + return {1, 0, 2, 3, 4, 5}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {0, 1, 2, 3, 4}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {0, 1, 2, 3, 4, 5}; + } + if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {0, 1, 3, 2}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {0, 1, 4, 2, 3}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {0, 1, 5, 2, 3, 4}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {2, 0, 3, 1}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {3, 0, 4, 1, 2}; + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + return {4, 0, 5, 1, 2, 3}; + } + else + { + printf("%s\n", __func__); + throw std::runtime_error("wrong! unsupported layout"); + } +} + +} // namespace detail + +// make tensor descriptor for packed input tensor, and order the dimension in the order of GNCHW +// regardless of physical layout +template +HostTensorDescriptor +make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck::utils::conv::ConvParam& param) +{ + std::vector physical_lengths; + + // HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function, + // is used by some legacy kernel. New kernel should use GNHWK/GKYXC/GNHWK + // TODO: remove this branch after removing legacy kernel + if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + if(param.G_ != 1) + { + throw std::runtime_error("wrong! 
G != 1"); + } + + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.N_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.begin() + 2, + param.input_spatial_lengths_.begin(), + param.input_spatial_lengths_.begin() + param.num_dim_spatial_); + } + // separate from legacy code above + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.N_), + static_cast(param.G_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.end(), + param.input_spatial_lengths_.begin(), + param.input_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.N_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.end(), + param.input_spatial_lengths_.begin(), + param.input_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.N_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.begin() + 2, + param.input_spatial_lengths_.begin(), + param.input_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.N_), + static_cast(param.G_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.begin() + 1, + param.input_spatial_lengths_.begin(), + param.input_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else + { + printf("%s\n", __func__); + printf("%s\n", InLayout::name); + throw std::runtime_error("wrong! unsupported layout"); + } + + return transpose_host_tensor_descriptor_given_new2old( + HostTensorDescriptor(physical_lengths), + detail::get_layout_transpose_gnchw_to_old()); +} + +// make tensor descriptor for packed weight tensor, and order the dimension in the order of GKCYX +// regardless of physical layout +template +HostTensorDescriptor +make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck::utils::conv::ConvParam& param) +{ + std::vector physical_lengths; + + // HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function, + // is used by some legacy kernel. New kernel should use GNHWK/GKYXC/GNHWK + // TODO: remove this branch after removing legacy kernel + if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + if(param.G_ != 1) + { + throw std::runtime_error("wrong! G != 1"); + } + + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.K_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.begin() + 2, + param.filter_spatial_lengths_.begin(), + param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); + } + // separate from legacy code above + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + if(param.G_ != 1) + { + throw std::runtime_error("wrong! 
G != 1"); + } + + physical_lengths = std::vector{static_cast(param.K_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.end(), + param.filter_spatial_lengths_.begin(), + param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.K_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.end(), + param.filter_spatial_lengths_.begin(), + param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.K_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.begin() + 2, + param.filter_spatial_lengths_.begin(), + param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.K_), + static_cast(param.G_), + static_cast(param.C_)}; + + physical_lengths.insert(physical_lengths.begin() + 1, + param.filter_spatial_lengths_.begin(), + param.filter_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else + { + printf("%s\n", __func__); + printf("%s\n", WeiLayout::name); + throw std::runtime_error("wrong! unsupported layout"); + } + + return transpose_host_tensor_descriptor_given_new2old( + HostTensorDescriptor(physical_lengths), + detail::get_layout_transpose_gnchw_to_old()); +} + +// make tensor descriptor for packed output tensor, and order the dimension in the order of GNKHW +// regardless of physical layout +template +HostTensorDescriptor +make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck::utils::conv::ConvParam& param) +{ + std::vector physical_lengths; + + // HACK: NHWC/KYXC/NHWK, which is treated as GNHWC/GKYXC/GNHWK by this function, + // is used by some legacy kernel. New kernel should use GNHWK/GKYXC/GNHWK + // TODO: remove this branch after removing legacy kernel + if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + if(param.G_ != 1) + { + throw std::runtime_error("wrong! 
G != 1"); + } + + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.N_), + static_cast(param.K_)}; + + physical_lengths.insert(physical_lengths.begin() + 2, + param.output_spatial_lengths_.begin(), + param.output_spatial_lengths_.begin() + param.num_dim_spatial_); + } + // separate from legacy code above + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.N_), + static_cast(param.K_)}; + + physical_lengths.insert(physical_lengths.end(), + param.output_spatial_lengths_.begin(), + param.output_spatial_lengths_.begin() + param.num_dim_spatial_); + } + // separate from legacy code above + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.N_), + static_cast(param.G_), + static_cast(param.K_)}; + + physical_lengths.insert(physical_lengths.end(), + param.output_spatial_lengths_.begin(), + param.output_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.G_), + static_cast(param.N_), + static_cast(param.K_)}; + + physical_lengths.insert(physical_lengths.begin() + 2, + param.output_spatial_lengths_.begin(), + param.output_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else if constexpr(ck::is_same_v || + ck::is_same_v || + ck::is_same_v) + { + physical_lengths = std::vector{static_cast(param.N_), + static_cast(param.G_), + static_cast(param.K_)}; + + physical_lengths.insert(physical_lengths.begin() + 1, + param.output_spatial_lengths_.begin(), + param.output_spatial_lengths_.begin() + param.num_dim_spatial_); + } + else + { + printf("%s\n", __func__); + printf("%s\n", OutLayout::name); + throw std::runtime_error("wrong! unsupported layout"); + } + + return transpose_host_tensor_descriptor_given_new2old( + HostTensorDescriptor(physical_lengths), + detail::get_layout_transpose_gnchw_to_old()); +} + +} // namespace conv +} // namespace utils +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/library/utility/convolution_parameter.hpp b/csrc/rocm/ck_extensions/include/ck/library/utility/convolution_parameter.hpp new file mode 100755 index 000000000000..70d581a67e70 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/library/utility/convolution_parameter.hpp @@ -0,0 +1,113 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" + +#include "ck/library/utility/numeric.hpp" + +namespace ck { +namespace utils { +namespace conv { + +struct ConvParam +{ + ConvParam(); + ConvParam(ck::index_t n_dim, + ck::index_t group_count, + ck::index_t n_batch, + ck::index_t n_out_channels, + ck::index_t n_in_channels, + const std::vector& filters_len, + const std::vector& input_len, + const std::vector& strides, + const std::vector& dilations, + const std::vector& left_pads, + const std::vector& right_pads); + + ConvParam(ck::long_index_t n_dim, + ck::long_index_t group_count, + ck::long_index_t n_batch, + ck::long_index_t n_out_channels, + ck::long_index_t n_in_channels, + const std::vector& filters_len, + const std::vector& input_len, + const std::vector& strides, + const std::vector& dilations, + const std::vector& left_pads, + const std::vector& right_pads); + + ck::long_index_t num_dim_spatial_; + ck::long_index_t G_; + ck::long_index_t N_; + ck::long_index_t K_; + ck::long_index_t C_; + + std::vector filter_spatial_lengths_; + std::vector input_spatial_lengths_; + std::vector output_spatial_lengths_; + + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + + std::vector input_left_pads_; + std::vector input_right_pads_; + + std::vector GetOutputSpatialLengths() const; + + std::size_t GetFlops() const; + + template + std::size_t GetInputByte() const + { + // sizeof(InDataType) * (G * N * C * ) + + return sizeof(InDataType) * + (G_ * N_ * C_ * + ck::accumulate_n( + std::begin(input_spatial_lengths_), num_dim_spatial_, 1, std::multiplies<>())); + } + + template + std::size_t GetWeightByte() const + { + // sizeof(WeiDataType) * (G * K * C * ) + + return sizeof(WeiDataType) * + (G_ * K_ * C_ * + ck::accumulate_n( + std::begin(filter_spatial_lengths_), num_dim_spatial_, 1, std::multiplies<>())); + } + + template + std::size_t GetOutputByte() const + { + // sizeof(OutDataType) * (G * N * K * ); + return sizeof(OutDataType) * (G_ * N_ * K_ * + std::accumulate(std::begin(output_spatial_lengths_), + std::end(output_spatial_lengths_), + static_cast(1), + std::multiplies())); + } + + template + std::size_t GetByte() const + { + return GetInputByte() + GetWeightByte() + + GetOutputByte(); + } +}; + +std::string get_conv_param_parser_helper_msg(); + +ConvParam parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[]); + +} // namespace conv +} // namespace utils +} // namespace ck + +std::ostream& operator<<(std::ostream& os, const ck::utils::conv::ConvParam& p); diff --git a/csrc/rocm/ck_extensions/include/ck/library/utility/device_memory.hpp b/csrc/rocm/ck_extensions/include/ck/library/utility/device_memory.hpp new file mode 100755 index 000000000000..d2e611a77cf6 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/library/utility/device_memory.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
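// Editorial sketch (not part of the CK sources): the size arithmetic behind ConvParam above.
// GetInputByte/GetWeightByte/GetOutputByte reduce to sizeof(T) * G * N * C (or K) * prod(spatial).
// GetOutputSpatialLengths() is only declared in this header, so using the standard convolution
// output-size formula below is an assumption of this sketch; every concrete number is hypothetical
// and std::uint16_t merely stands in for a 2-byte type such as bf16.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

static std::int64_t conv_out_len(std::int64_t in, std::int64_t filter, std::int64_t stride,
                                 std::int64_t dilation, std::int64_t pad_l, std::int64_t pad_r)
{
    return (in + pad_l + pad_r - dilation * (filter - 1) - 1) / stride + 1;
}

int main()
{
    const std::int64_t G = 1, N = 2, K = 128, C = 64;
    const std::vector<std::int64_t> in_hw{224, 224}, filt_hw{3, 3};
    std::vector<std::int64_t> out_hw;
    for(std::size_t d = 0; d < in_hw.size(); ++d)
        out_hw.push_back(conv_out_len(in_hw[d], filt_hw[d], /*stride*/ 1, /*dilation*/ 1,
                                      /*pad_l*/ 1, /*pad_r*/ 1)); // 224 for a 3x3/s1/p1 conv

    auto prod = [](const std::vector<std::int64_t>& v) {
        return std::accumulate(v.begin(), v.end(), std::int64_t{1}, std::multiplies<>{});
    };

    const std::size_t in_bytes  = sizeof(std::uint16_t) * G * N * C * prod(in_hw);
    const std::size_t wei_bytes = sizeof(std::uint16_t) * G * K * C * prod(filt_hw);
    const std::size_t out_bytes = sizeof(std::uint16_t) * G * N * K * prod(out_hw);
    std::cout << in_bytes << " + " << wei_bytes << " + " << out_bytes << " bytes\n";
}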
+ +#pragma once + +#include + +template +__global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size) +{ + for(uint64_t i = threadIdx.x; i < buffer_element_size; i += blockDim.x) + { + p[i] = x; + } +} + +/** + * @brief Container for storing data in GPU device memory + * + */ +struct DeviceMem +{ + DeviceMem() : mpDeviceBuf(nullptr), mMemSize(0) {} + DeviceMem(std::size_t mem_size); + void Realloc(std::size_t mem_size); + void* GetDeviceBuffer() const; + std::size_t GetBufferSize() const; + void ToDevice(const void* p) const; + void ToDevice(const void* p, const std::size_t cpySize) const; + void FromDevice(void* p) const; + void FromDevice(void* p, const std::size_t cpySize) const; + void SetZero() const; + template + void SetValue(T x) const; + ~DeviceMem(); + + void* mpDeviceBuf; + std::size_t mMemSize; +}; + +template +void DeviceMem::SetValue(T x) const +{ + if(mMemSize % sizeof(T) != 0) + { + throw std::runtime_error("wrong! not entire DeviceMem will be set"); + } + + set_buffer_value<<<1, 1024>>>(static_cast(mpDeviceBuf), x, mMemSize / sizeof(T)); +} diff --git a/csrc/rocm/ck_extensions/include/ck/library/utility/fill.hpp b/csrc/rocm/ck_extensions/include/ck/library/utility/fill.hpp new file mode 100755 index 000000000000..4f421b4282f6 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/library/utility/fill.hpp @@ -0,0 +1,186 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" + +namespace ck { +namespace utils { + +template +struct FillUniformDistribution +{ + float a_{-5.f}; + float b_{5.f}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::mt19937 gen(11939); + std::uniform_real_distribution dis(a_, b_); + std::generate(first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + +// Normally FillUniformDistributionIntegerValue should use std::uniform_int_distribution as below. +// However this produces segfaults in std::mt19937 which look like inifite loop. +// template +// struct FillUniformDistributionIntegerValue +// { +// int a_{-5}; +// int b_{5}; +// +// template +// void operator()(ForwardIter first, ForwardIter last) const +// { +// std::mt19937 gen(11939); +// std::uniform_int_distribution dis(a_, b_); +// std::generate( +// first, last, [&dis, &gen]() { return ck::type_convert(dis(gen)); }); +// } +// }; + +// Workaround for uniform_int_distribution not working as expected. See note above.< +template +struct FillUniformDistributionIntegerValue +{ + float a_{-5.f}; + float b_{5.f}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::mt19937 gen(11939); + std::uniform_real_distribution dis(a_, b_); + std::generate( + first, last, [&dis, &gen]() { return ck::type_convert(std::round(dis(gen))); }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + +/** + * @brief A functor for filling a container with a monotonically increasing or decreasing sequence. 
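// Editorial usage sketch for the DeviceMem helper declared above. It assumes the header is
// reachable as "ck/library/utility/device_memory.hpp" (the include directory added in
// CMakeLists.txt), that the out-of-line member definitions are linked in, and that the file is
// compiled with hipcc (SetValue launches a __global__ kernel). Sizes and values are hypothetical.
#include <vector>
#include "ck/library/utility/device_memory.hpp"

void device_mem_roundtrip()
{
    std::vector<float> host(1024, 3.0f);

    DeviceMem buf(host.size() * sizeof(float)); // allocate device memory
    buf.ToDevice(host.data());                  // copy the whole host buffer to the device
    buf.SetValue(1.0f);                         // overwrite every float on the device with 1.0f
    buf.FromDevice(host.data());                // copy the device contents back to the host
    // the destructor releases the device allocation
}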
+ * + * FillMonotonicSeq generates a sequence of values starting from an initial value + * and incrementing by a fixed step for each subsequent element. + * + * @tparam T The numeric type of the sequence elements. + * + * Example usage: + * ``` + * std::vector v(5); + * FillMonotonicSeq{10, 2}(v); // Fills v with {10, 12, 14, 16, 18} + * ``` + */ +template +struct FillMonotonicSeq +{ + T init_value_{0}; + T step_{1}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::generate(first, last, [=, *this, n = init_value_]() mutable { + auto tmp = n; + n += step_; + return tmp; + }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + +template +struct FillConstant +{ + T value_{0}; + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::fill(first, last, value_); + } + + template + auto operator()(ForwardRange&& range) const -> std::void_t< + decltype(std::declval()(std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + +template +struct TransformIntoStructuralSparsity +{ + // clang-format off + static constexpr T valid_sequences[] = { + 0, 0, 1, 1, + 0, 1, 0, 1, + 0, 1, 1, 0, + 1, 0, 0, 1, + 1, 0, 1, 0, + 1, 1, 0, 0, + }; + // clang-format on + + template + void operator()(ForwardIter first, ForwardIter last) const + { + std::for_each(first, last, [=, *this, idx = 0](T& elem) mutable { + auto tmp_idx = idx; + idx += 1; + return elem *= valid_sequences[tmp_idx % (sizeof(valid_sequences) / sizeof(T))]; + }); + } + + template + auto operator()(ForwardRange&& range) const + -> std::void_t()( + std::begin(std::forward(range)), + std::end(std::forward(range))))> + { + (*this)(std::begin(std::forward(range)), + std::end(std::forward(range))); + } +}; + +} // namespace utils +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/library/utility/host_common_util.hpp b/csrc/rocm/ck_extensions/include/ck/library/utility/host_common_util.hpp new file mode 100755 index 000000000000..16d6e2233a27 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/library/utility/host_common_util.hpp @@ -0,0 +1,136 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
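// Editorial sketch (not part of the CK sources): what TransformIntoStructuralSparsity does to a
// range. Each group of four consecutive mask entries keeps exactly two elements (a 2:4 structured
// sparsity pattern); the functor multiplies element i by valid_sequences[i % 24]. Plain ints are
// used here and the concrete values are hypothetical.
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    constexpr int mask[] = {0, 0, 1, 1,  0, 1, 0, 1,  0, 1, 1, 0,
                            1, 0, 0, 1,  1, 0, 1, 0,  1, 1, 0, 0};
    std::vector<int> v(12, 7); // pretend these are nonzero weights
    for(std::size_t i = 0; i < v.size(); ++i)
        v[i] *= mask[i % (sizeof(mask) / sizeof(mask[0]))];
    for(int x : v)
        std::cout << x << ' '; // 0 0 7 7 0 7 0 7 0 7 7 0
    std::cout << '\n';
}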
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" + +namespace ck { + +namespace host_common { + +template +static inline void dumpBufferToFile(const char* fileName, T* data, size_t dataNumItems) +{ + std::ofstream outFile(fileName, std::ios::binary); + if(outFile) + { + outFile.write(reinterpret_cast(data), dataNumItems * sizeof(T)); + outFile.close(); + std::cout << "Write output to file " << fileName << std::endl; + } + else + { + std::cout << "Could not open file " << fileName << " for writing" << std::endl; + } +}; + +template +static inline T getSingleValueFromString(const std::string& valueStr) +{ + std::istringstream iss(valueStr); + + T val; + + iss >> val; + + return (val); +}; + +template +static inline std::vector getTypeValuesFromString(const char* cstr_values) +{ + std::string valuesStr(cstr_values); + + std::vector values; + std::size_t pos = 0; + std::size_t new_pos; + + new_pos = valuesStr.find(',', pos); + while(new_pos != std::string::npos) + { + const std::string sliceStr = valuesStr.substr(pos, new_pos - pos); + + T val = getSingleValueFromString(sliceStr); + + values.push_back(val); + + pos = new_pos + 1; + new_pos = valuesStr.find(',', pos); + }; + + std::string sliceStr = valuesStr.substr(pos); + T val = getSingleValueFromString(sliceStr); + + values.push_back(val); + + return (values); +} + +template +static inline std::vector> +get_index_set(const std::array& dim_lengths) +{ + static_assert(NDim >= 1, "NDim >= 1 is required to use this function!"); + + if constexpr(NDim == 1) + { + std::vector> index_set; + + for(int i = 0; i < dim_lengths[0]; i++) + { + std::array index{i}; + + index_set.push_back(index); + }; + + return index_set; + } + else + { + std::vector> index_set; + std::array partial_dim_lengths; + + std::copy(dim_lengths.begin() + 1, dim_lengths.end(), partial_dim_lengths.begin()); + + std::vector> partial_index_set; + + partial_index_set = get_index_set(partial_dim_lengths); + + for(index_t i = 0; i < dim_lengths[0]; i++) + for(const auto& partial_index : partial_index_set) + { + std::array index; + + index[0] = i; + + std::copy(partial_index.begin(), partial_index.end(), index.begin() + 1); + + index_set.push_back(index); + }; + + return index_set; + }; +}; + +template +static inline size_t get_offset_from_index(const std::array& strides, + const std::array& index) +{ + size_t offset = 0; + + for(int i = 0; i < NDim; i++) + offset += index[i] * strides[i]; + + return (offset); +}; + +} // namespace host_common +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/library/utility/host_gemm.hpp b/csrc/rocm/ck_extensions/include/ck/library/utility/host_gemm.hpp new file mode 100755 index 000000000000..5eb7e3b8c974 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/library/utility/host_gemm.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
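// Editorial sketch (not part of the CK sources): two helpers from host_common_util.hpp above.
// The first part splits a comma-separated argument string the way getTypeValuesFromString()
// does (written here with std::getline); the second reproduces the stride dot-product used by
// get_offset_from_index(). Values are hypothetical.
#include <array>
#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main()
{
    // "3,224,224" -> {3, 224, 224}
    std::istringstream iss("3,224,224");
    std::vector<int> values;
    for(std::string tok; std::getline(iss, tok, ',');)
        values.push_back(std::stoi(tok));

    // offset of index (1, 2, 3) in a packed 3-D tensor with strides {224*224, 224, 1}
    std::array<std::size_t, 3> strides{224 * 224, 224, 1};
    std::array<std::size_t, 3> index{1, 2, 3};
    std::size_t offset = 0;
    for(std::size_t i = 0; i < index.size(); ++i)
        offset += index[i] * strides[i]; // same accumulation as get_offset_from_index()

    std::cout << values.size() << " values parsed, offset = " << offset << '\n';
}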
+ +#pragma once + +#include "host_tensor.hpp" + +template +void host_gemm_mk_kn_mn(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op) +{ + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = a_m_k.mDesc.GetLengths()[1]; + + float v_acc = 0; + + for(int k = 0; k < K; ++k) + { + float v_a; + float v_b; + + a_element_op(v_a, static_cast(a_m_k(m, k))); + b_element_op(v_b, static_cast(b_k_n(k, n))); + + v_acc += v_a * v_b; + } + + float v_c; + + c_element_op(v_c, v_acc); + + c_m_n(m, n) = v_c; + }; + + make_ParallelTensorFunctor(f_mk_kn_mn, + c_m_n.mDesc.GetLengths()[0], + c_m_n.mDesc.GetLengths()[1])(std::thread::hardware_concurrency()); +} diff --git a/csrc/rocm/ck_extensions/include/ck/library/utility/host_tensor.hpp b/csrc/rocm/ck_extensions/include/ck/library/utility/host_tensor.hpp new file mode 100755 index 000000000000..33c918c99780 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/library/utility/host_tensor.hpp @@ -0,0 +1,714 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck/utility/data_type.hpp" +#include "ck/utility/span.hpp" +#include "ck/utility/type_convert.hpp" + +#include "ck/library/utility/algorithm.hpp" +#include "ck/library/utility/ranges.hpp" +#include "ck/library/utility/thread.hpp" + +template +std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim) +{ + bool first = true; + for(auto&& v : range) + { + if(first) + first = false; + else + os << delim; + os << v; + } + return os; +} + +template +std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) +{ + bool first = true; + for(auto&& v : range) + { + if(first) + first = false; + else + os << delim; + + using RangeType = ck::remove_cvref_t; + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v) + { + os << ck::type_convert(v); + } + else if constexpr(std::is_same_v || + std::is_same_v) + { + const auto packed_floats = ck::type_convert(v); + const ck::vector_type vector_of_floats{packed_floats}; + os << vector_of_floats.template AsType()[ck::Number<0>{}] << delim + << vector_of_floats.template AsType()[ck::Number<1>{}]; + } + else + { + os << static_cast(v); + } + } + return os; +} + +template +auto call_f_unpack_args_impl(F f, T args, std::index_sequence) +{ + return f(std::get(args)...); +} + +template +auto call_f_unpack_args(F f, T args) +{ + constexpr std::size_t N = std::tuple_size{}; + + return call_f_unpack_args_impl(f, args, std::make_index_sequence{}); +} + +template +auto construct_f_unpack_args_impl(T args, std::index_sequence) +{ + return F(std::get(args)...); +} + +template +auto construct_f_unpack_args(F, T args) +{ + constexpr std::size_t N = std::tuple_size{}; + + return construct_f_unpack_args_impl(args, std::make_index_sequence{}); +} + +struct HostTensorDescriptor +{ + HostTensorDescriptor() = default; + + void CalculateStrides(); + + template >> + HostTensorDescriptor(const std::initializer_list& lens) : mLens(lens.begin(), lens.end()) + { + this->CalculateStrides(); + } + + HostTensorDescriptor(const std::initializer_list& lens) + : mLens(lens.begin(), lens.end()) + { + this->CalculateStrides(); + } + + template , std::size_t> || + std::is_convertible_v, ck::long_index_t>>> + HostTensorDescriptor(const 
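// Editorial sketch (not part of the CK sources): the arithmetic performed by host_gemm_mk_kn_mn
// above, written against plain std::vector matrices instead of Tensor<> and with the elementwise
// operators treated as identity. Accumulation is done in float, as in the reference loop.
#include <cstddef>
#include <vector>

void reference_gemm_mk_kn_mn(const std::vector<float>& a, // M x K, row-major
                             const std::vector<float>& b, // K x N, row-major
                             std::vector<float>& c,       // M x N, row-major
                             std::size_t M, std::size_t N, std::size_t K)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(std::size_t k = 0; k < K; ++k)
                acc += a[m * K + k] * b[k * N + n]; // v_acc += v_a * v_b
            c[m * N + n] = acc;
        }
}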
Lengths& lens) : mLens(lens.begin(), lens.end()) + { + this->CalculateStrides(); + } + + template && + std::is_convertible_v>> + HostTensorDescriptor(const std::initializer_list& lens, + const std::initializer_list& strides) + : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) + { + } + + HostTensorDescriptor(const std::initializer_list& lens, + const std::initializer_list& strides) + : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) + { + } + + template , std::size_t> && + std::is_convertible_v, std::size_t>) || + (std::is_convertible_v, ck::long_index_t> && + std::is_convertible_v, ck::long_index_t>)>> + HostTensorDescriptor(const Lengths& lens, const Strides& strides) + : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end()) + { + } + + std::size_t GetNumOfDimension() const; + std::size_t GetElementSize() const; + std::size_t GetElementSpaceSize() const; + + const std::vector& GetLengths() const; + const std::vector& GetStrides() const; + + template + std::size_t GetOffsetFromMultiIndex(Is... is) const + { + assert(sizeof...(Is) == this->GetNumOfDimension()); + std::initializer_list iss{static_cast(is)...}; + return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); + } + + std::size_t GetOffsetFromMultiIndex(const std::vector& iss) const + { + return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0}); + } + + friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc); + + private: + std::vector mLens; + std::vector mStrides; +}; + +template +HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor& a, + const New2Old& new2old) +{ + std::vector new_lengths(a.GetNumOfDimension()); + std::vector new_strides(a.GetNumOfDimension()); + + for(std::size_t i = 0; i < a.GetNumOfDimension(); i++) + { + new_lengths[i] = a.GetLengths()[new2old[i]]; + new_strides[i] = a.GetStrides()[new2old[i]]; + } + + return HostTensorDescriptor(new_lengths, new_strides); +} + +struct joinable_thread : std::thread +{ + template + joinable_thread(Xs&&... xs) : std::thread(std::forward(xs)...) + { + } + + joinable_thread(joinable_thread&&) = default; + joinable_thread& operator=(joinable_thread&&) = default; + + ~joinable_thread() + { + if(this->joinable()) + this->join(); + } +}; + +template +struct ParallelTensorFunctor +{ + F mF; + static constexpr std::size_t NDIM = sizeof...(Xs); + std::array mLens; + std::array mStrides; + std::size_t mN1d; + + ParallelTensorFunctor(F f, Xs... 
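// Editorial sketch (not part of the CK sources): packed row-major strides and the inner-product
// offset used by HostTensorDescriptor::GetOffsetFromMultiIndex above. CalculateStrides() is not
// defined in this header, so the packed convention is an assumption here; the reverse partial_sum
// is the same construction ParallelTensorFunctor's constructor uses for its strides.
#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    std::vector<std::size_t> lens{2, 3, 4};
    std::vector<std::size_t> strides(lens.size());
    strides.back() = 1;
    std::partial_sum(lens.rbegin(), lens.rend() - 1, strides.rbegin() + 1,
                     std::multiplies<std::size_t>{}); // strides = {12, 4, 1}

    std::vector<std::size_t> idx{1, 2, 3};
    std::size_t offset =
        std::inner_product(idx.begin(), idx.end(), strides.begin(), std::size_t{0});
    std::cout << offset << '\n'; // 1*12 + 2*4 + 3*1 = 23
}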
xs) : mF(f), mLens({static_cast(xs)...}) + { + mStrides.back() = 1; + std::partial_sum(mLens.rbegin(), + mLens.rend() - 1, + mStrides.rbegin() + 1, + std::multiplies()); + mN1d = mStrides[0] * mLens[0]; + } + + std::array GetNdIndices(std::size_t i) const + { + std::array indices; + + for(std::size_t idim = 0; idim < NDIM; ++idim) + { + indices[idim] = i / mStrides[idim]; + i -= indices[idim] * mStrides[idim]; + } + + return indices; + } + + void operator()(std::size_t num_thread = 1) const + { + std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread; + + std::vector threads(num_thread); + + for(std::size_t it = 0; it < num_thread; ++it) + { + std::size_t iw_begin = it * work_per_thread; + std::size_t iw_end = std::min((it + 1) * work_per_thread, mN1d); + + auto f = [=, *this] { + for(std::size_t iw = iw_begin; iw < iw_end; ++iw) + { + call_f_unpack_args(mF, GetNdIndices(iw)); + } + }; + threads[it] = joinable_thread(f); + } + } +}; + +template +auto make_ParallelTensorFunctor(F f, Xs... xs) +{ + return ParallelTensorFunctor(f, xs...); +} + +template +struct Tensor +{ + using Descriptor = HostTensorDescriptor; + using Data = std::vector; + + template + Tensor(std::initializer_list lens) : mDesc(lens), mData(GetElementSpaceSize()) + { + } + + template + Tensor(std::initializer_list lens, std::initializer_list strides) + : mDesc(lens, strides), mData(GetElementSpaceSize()) + { + } + + template + Tensor(const Lengths& lens) : mDesc(lens), mData(GetElementSpaceSize()) + { + } + + template + Tensor(const Lengths& lens, const Strides& strides) + : mDesc(lens, strides), mData(GetElementSpaceSize()) + { + } + + Tensor(const Descriptor& desc) : mDesc(desc), mData(GetElementSpaceSize()) {} + + template + Tensor CopyAsType() const + { + Tensor ret(mDesc); + + ck::ranges::transform( + mData, ret.mData.begin(), [](auto value) { return ck::type_convert(value); }); + + return ret; + } + + Tensor() = delete; + Tensor(const Tensor&) = default; + Tensor(Tensor&&) = default; + + ~Tensor() = default; + + Tensor& operator=(const Tensor&) = default; + Tensor& operator=(Tensor&&) = default; + + template + explicit Tensor(const Tensor& other) : Tensor(other.template CopyAsType()) + { + } + void savetxt(std::string file_name, std::string dtype = "float") + { + std::ofstream file(file_name); + + if(file.is_open()) + { + for(auto& itm : mData) + { + if(dtype == "float") + file << ck::type_convert(itm) << std::endl; + else if(dtype == "int") + file << ck::type_convert(itm) << std::endl; + else + // TODO: we didn't implement operator<< for all custom + // data types, here fall back to float in case compile error + file << ck::type_convert(itm) << std::endl; + } + file.close(); + } + else + { + // Print an error message to the standard error + // stream if the file cannot be opened. 
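// Editorial sketch (not part of the CK sources): the linear-index to N-d index decomposition that
// ParallelTensorFunctor::GetNdIndices() performs above, on plain arrays. With strides {12, 4, 1}
// for a 2x3x4 shape, linear index 23 maps back to (1, 2, 3).
#include <array>
#include <cstddef>
#include <iostream>

int main()
{
    constexpr std::array<std::size_t, 3> strides{12, 4, 1};
    std::size_t i = 23;
    std::array<std::size_t, 3> indices{};
    for(std::size_t d = 0; d < strides.size(); ++d)
    {
        indices[d] = i / strides[d];      // how many full strides fit
        i -= indices[d] * strides[d];     // remainder carries to the next dimension
    }
    std::cout << indices[0] << ' ' << indices[1] << ' ' << indices[2] << '\n'; // 1 2 3
}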
+ throw std::runtime_error(std::string("unable to open file:") + file_name); + } + } + decltype(auto) GetLengths() const { return mDesc.GetLengths(); } + + decltype(auto) GetStrides() const { return mDesc.GetStrides(); } + + std::size_t GetNumOfDimension() const { return mDesc.GetNumOfDimension(); } + + std::size_t GetElementSize() const { return mDesc.GetElementSize(); } + + std::size_t GetElementSpaceSize() const + { + if constexpr(ck::is_packed_type_v>) + { + return (mDesc.GetElementSpaceSize() + 1) / ck::packed_size_v>; + } + else + { + return mDesc.GetElementSpaceSize(); + } + } + + std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); } + + void SetZero() { ck::ranges::fill(mData, T{0}); } + + template + void ForEach_impl(F&& f, std::vector& idx, size_t rank) + { + if(rank == mDesc.GetNumOfDimension()) + { + f(*this, idx); + return; + } + // else + for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++) + { + idx[rank] = i; + ForEach_impl(std::forward(f), idx, rank + 1); + } + } + + template + void ForEach(F&& f) + { + std::vector idx(mDesc.GetNumOfDimension(), 0); + ForEach_impl(std::forward(f), idx, size_t(0)); + } + + template + void ForEach_impl(const F&& f, std::vector& idx, size_t rank) const + { + if(rank == mDesc.GetNumOfDimension()) + { + f(*this, idx); + return; + } + // else + for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++) + { + idx[rank] = i; + ForEach_impl(std::forward(f), idx, rank + 1); + } + } + + template + void ForEach(const F&& f) const + { + std::vector idx(mDesc.GetNumOfDimension(), 0); + ForEach_impl(std::forward(f), idx, size_t(0)); + } + + template + void GenerateTensorValue(G g, std::size_t num_thread = 1) + { + switch(mDesc.GetNumOfDimension()) + { + case 1: { + auto f = [&](auto i) { (*this)(i) = g(i); }; + make_ParallelTensorFunctor(f, mDesc.GetLengths()[0])(num_thread); + break; + } + case 2: { + auto f = [&](auto i0, auto i1) { (*this)(i0, i1) = g(i0, i1); }; + make_ParallelTensorFunctor(f, mDesc.GetLengths()[0], mDesc.GetLengths()[1])(num_thread); + break; + } + case 3: { + auto f = [&](auto i0, auto i1, auto i2) { (*this)(i0, i1, i2) = g(i0, i1, i2); }; + make_ParallelTensorFunctor( + f, mDesc.GetLengths()[0], mDesc.GetLengths()[1], mDesc.GetLengths()[2])(num_thread); + break; + } + case 4: { + auto f = [&](auto i0, auto i1, auto i2, auto i3) { + (*this)(i0, i1, i2, i3) = g(i0, i1, i2, i3); + }; + make_ParallelTensorFunctor(f, + mDesc.GetLengths()[0], + mDesc.GetLengths()[1], + mDesc.GetLengths()[2], + mDesc.GetLengths()[3])(num_thread); + break; + } + case 5: { + auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4) { + (*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4); + }; + make_ParallelTensorFunctor(f, + mDesc.GetLengths()[0], + mDesc.GetLengths()[1], + mDesc.GetLengths()[2], + mDesc.GetLengths()[3], + mDesc.GetLengths()[4])(num_thread); + break; + } + case 6: { + auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4, auto i5) { + (*this)(i0, i1, i2, i3, i4, i5) = g(i0, i1, i2, i3, i4, i5); + }; + make_ParallelTensorFunctor(f, + mDesc.GetLengths()[0], + mDesc.GetLengths()[1], + mDesc.GetLengths()[2], + mDesc.GetLengths()[3], + mDesc.GetLengths()[4], + mDesc.GetLengths()[5])(num_thread); + break; + } + case 12: { + auto f = [&](auto i0, + auto i1, + auto i2, + auto i3, + auto i4, + auto i5, + auto i6, + auto i7, + auto i8, + auto i9, + auto i10, + auto i11) { + (*this)(i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11) = + g(i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11); + }; + 
make_ParallelTensorFunctor(f, + mDesc.GetLengths()[0], + mDesc.GetLengths()[1], + mDesc.GetLengths()[2], + mDesc.GetLengths()[3], + mDesc.GetLengths()[4], + mDesc.GetLengths()[5], + mDesc.GetLengths()[6], + mDesc.GetLengths()[7], + mDesc.GetLengths()[8], + mDesc.GetLengths()[9], + mDesc.GetLengths()[10], + mDesc.GetLengths()[11])(num_thread); + break; + } + default: throw std::runtime_error("unspported dimension"); + } + } + + // Generate random values with multiple threads. Guaranteed to give the same sequence with any + // number of threads provided. + template , + typename Mapping = ck::identity, + typename Generator = std::minstd_rand> + void GenerateTensorDistr(Distribution dis = {0.f, 1.f}, + Mapping fn = {}, + const Generator g = Generator(0), // default seed 0 + std::size_t num_thread = -1) + { + using ck::math::integer_divide_ceil; + using ck::math::min; + if(num_thread == -1ULL) + num_thread = min(ck::get_available_cpu_cores(), 80U); // max 80 threads + // At least 2MB per thread + num_thread = min(num_thread, integer_divide_ceil(this->GetElementSpaceSize(), 0x200000)); + constexpr std::size_t BLOCK_BYTES = 64; + constexpr std::size_t BLOCK_SIZE = BLOCK_BYTES / sizeof(T); + + const std::size_t num_blocks = integer_divide_ceil(this->GetElementSpaceSize(), BLOCK_SIZE); + const std::size_t blocks_per_thread = integer_divide_ceil(num_blocks, num_thread); + + std::vector threads; + threads.reserve(num_thread - 1); + const auto dst = const_cast(this->mData.data()); + const auto element_space_size = this->GetElementSpaceSize(); + for(int it = num_thread - 1; it >= 0; --it) + { + std::size_t ib_begin = it * blocks_per_thread; + std::size_t ib_end = min(ib_begin + blocks_per_thread, num_blocks); + + auto job = [=]() { + auto g_ = g; // copy + auto dis_ = dis; // copy + g_.discard(ib_begin * BLOCK_SIZE * ck::packed_size_v); + auto t_fn = [&]() { + // As user can pass integer distribution in dis, we must ensure that the correct + // constructor/converter is called at all times. For f4/f6/f8 types, to ensure + // correct results, we convert from float to the target type. In these cases + // integer constructors are interpreted as direct initialization of the internal + // storage with binary values instead of treating integers as subset of floats. 
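// Editorial sketch (not part of the CK sources): why GenerateTensorDistr above gives the same
// contents for any thread count. Each worker copies the generator and calls discard() with the
// index of its first element, so thread boundaries never change the overall pseudo-random
// sequence. The raw generator is used here for clarity; GenerateTensorDistr applies the same idea
// with a per-thread copy of the distribution as well. Sizes are hypothetical.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

int main()
{
    const std::size_t n = 16, split = 10; // pretend two "threads" split the work at index 10

    std::minstd_rand serial_gen(0);
    std::vector<std::uint32_t> serial(n);
    for(auto& x : serial) x = serial_gen();

    std::vector<std::uint32_t> parallel(n);
    std::minstd_rand g0(0);               // worker 0: elements [0, split)
    for(std::size_t i = 0; i < split; ++i) parallel[i] = g0();
    std::minstd_rand g1(0);               // worker 1: elements [split, n)
    g1.discard(split);                    // skip exactly what worker 0 consumed
    for(std::size_t i = split; i < n; ++i) parallel[i] = g1();

    std::cout << (serial == parallel ? "identical\n" : "different\n"); // identical
}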
+ if constexpr(ck::is_same_v || ck::is_same_v) + return ck::type_convert(static_cast(fn(dis_(g_)))); + else if constexpr(ck::packed_size_v == 1) + return ck::type_convert(fn(dis_(g_))); + else if constexpr(ck::is_same_v) + return ck::f4x2_pk_t{ck::type_convert( + ck::float2_t{ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_)))})}; + else if constexpr(ck::is_same_v || + ck::is_same_v) + { + return ck::type_convert( + ck::float32_t{ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_)))}); + } + else if constexpr(ck::is_same_v || + ck::is_same_v) + { + return ck::type_convert( + ck::float16_t{ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_))), + ck::type_convert(fn(dis_(g_)))}); + } + else + static_assert(false, "Unsupported packed size for T"); + }; + + std::size_t ib = ib_begin; + for(; ib < ib_end - 1; ++ib) + ck::static_for<0, BLOCK_SIZE, 1>{}([&](auto iw_) { + constexpr size_t iw = iw_.value; + dst[ib * BLOCK_SIZE + iw] = t_fn(); + }); + for(std::size_t iw = 0; iw < BLOCK_SIZE; ++iw) + if(ib * BLOCK_SIZE + iw < element_space_size) + dst[ib * BLOCK_SIZE + iw] = t_fn(); + }; + + if(it > 0) + threads.emplace_back(std::move(job)); + else + job(); // last job run in the main thread + } + for(auto& t : threads) + t.join(); + } + + template + std::size_t GetOffsetFromMultiIndex(Is... is) const + { + return mDesc.GetOffsetFromMultiIndex(is...) / ck::packed_size_v>; + } + + template + T& operator()(Is... is) + { + return mData[mDesc.GetOffsetFromMultiIndex(is...) / + ck::packed_size_v>]; + } + + template + const T& operator()(Is... is) const + { + return mData[mDesc.GetOffsetFromMultiIndex(is...) 
/ + ck::packed_size_v>]; + } + + T& operator()(const std::vector& idx) + { + return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v>]; + } + + const T& operator()(const std::vector& idx) const + { + return mData[mDesc.GetOffsetFromMultiIndex(idx) / ck::packed_size_v>]; + } + + typename Data::iterator begin() { return mData.begin(); } + + typename Data::iterator end() { return mData.end(); } + + typename Data::pointer data() { return mData.data(); } + + typename Data::const_iterator begin() const { return mData.begin(); } + + typename Data::const_iterator end() const { return mData.end(); } + + typename Data::const_pointer data() const { return mData.data(); } + + typename Data::size_type size() const { return mData.size(); } + + template + auto AsSpan() const + { + constexpr std::size_t FromSize = sizeof(T); + constexpr std::size_t ToSize = sizeof(U); + + using Element = std::add_const_t>; + return ck::span{reinterpret_cast(data()), size() * FromSize / ToSize}; + } + + template + auto AsSpan() + { + constexpr std::size_t FromSize = sizeof(T); + constexpr std::size_t ToSize = sizeof(U); + + using Element = std::remove_reference_t; + return ck::span{reinterpret_cast(data()), size() * FromSize / ToSize}; + } + + Descriptor mDesc; + Data mData; +}; diff --git a/csrc/rocm/ck_extensions/include/ck/library/utility/host_tensor_generator.hpp b/csrc/rocm/ck_extensions/include/ck/library/utility/host_tensor_generator.hpp new file mode 100755 index 000000000000..ab69412c1552 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/library/utility/host_tensor_generator.hpp @@ -0,0 +1,669 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" + +template +struct GeneratorTensor_0 +{ + template + T operator()(Is...) + { + return T{0}; + } +}; + +template +struct GeneratorTensor_1 +{ + T value = 1; + + template + T operator()(Is...) + { + return value; + } +}; + +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::half_t operator()(Is...) + { + return ck::type_convert(value); + } +}; + +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::bhalf_t operator()(Is...) + { + return ck::type_convert(value); + } +}; + +#if defined CK_ENABLE_FP8 +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::f8_t operator()(Is...) + { + return ck::type_convert(value); + } +}; + +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::bf8_t operator()(Is...) + { + return ck::type_convert(value); + } +}; +#endif + +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::f4_t operator()(Is...) + { + return ck::type_convert(value); + } +}; + +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::f4x2_pk_t operator()(Is...) + { + return ck::f4x2_pk_t{ck::type_convert(ck::float2_t{value, value})}; + } +}; + +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::f6x32_pk_t operator()(Is...) + { + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + r.pack(ck::type_convert(value), static_cast(i)); + }); + return r; + } +}; + +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::bf6x32_pk_t operator()(Is...) 
+ { + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + r.pack(ck::type_convert(value), static_cast(i)); + }); + return r; + } +}; + +template <> +struct GeneratorTensor_1 +{ + int8_t value = 1; + + template + int8_t operator()(Is...) + { + return value; + } +}; + +template <> +struct GeneratorTensor_1 +{ + int8_t value = 1; + + template + ck::pk_i4_t operator()(Is...) + { + int t = value + 8; + ck::pk_i4_t r = ((t << 4) + t) & 0xff; + return r; + } +}; + +template <> +struct GeneratorTensor_1 +{ + float value = 1; + + template + ck::e8m0_bexp_t operator()(Is...) + { + return ck::type_convert(value); + } +}; + +template +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + T operator()(Is...) + { + return static_cast((std::rand() % (max_value - min_value)) + min_value); + } +}; + +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::f6x32_pk_t operator()(Is...) + { + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + r.pack(ck::type_convert(tmp), static_cast(i)); + }); + + return r; + } +}; + +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::bf6x32_pk_t operator()(Is...) + { + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + r.pack(ck::type_convert(tmp), static_cast(i)); + }); + + return r; + } +}; + +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::bhalf_t operator()(Is...) + { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + return ck::type_convert(tmp); + } +}; + +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + int8_t operator()(Is...) + { + return (std::rand() % (max_value - min_value)) + min_value; + } +}; + +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::pk_i4_t operator()(Is...) + { + int hi = std::rand() % (max_value - min_value) + min_value + 8; + int lo = std::rand() % (max_value - min_value) + min_value + 8; + ck::pk_i4_t r = ((hi << 4) + lo) & 0xff; + return r; + } +}; + +#if defined CK_ENABLE_FP8 +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::f8_t operator()(Is...) + { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + return ck::type_convert(tmp); + } +}; +#endif + +#if defined CK_ENABLE_BF8 +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::bf8_t operator()(Is...) + { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + return ck::type_convert(tmp); + } +}; +#endif + +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::f4_t operator()(Is...) + { + float tmp = (std::rand() % (max_value - min_value)) + min_value; + return ck::type_convert(tmp); + } +}; + +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::f4x2_pk_t operator()(Is...) 
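// Editorial sketch (not part of the CK sources): the 4-bit packing used by the pk_i4_t generators
// above. Two signed 4-bit values are biased by +8 into [0, 15] and stored as the high and low
// nibble of one byte. The unpacking shown is simply the inverse of that packing step; how CK
// decodes pk_i4_t internally is not shown in this header, so that part is an assumption.
#include <cstdint>
#include <iostream>

int main()
{
    int hi = -3, lo = 5; // signed values in [-8, 7]
    std::uint8_t packed = static_cast<std::uint8_t>((((hi + 8) << 4) + (lo + 8)) & 0xff);

    int hi_back = ((packed >> 4) & 0xf) - 8; // recover the original values
    int lo_back = (packed & 0xf) - 8;
    std::cout << int(packed) << " -> " << hi_back << ", " << lo_back << '\n'; // 93 -> -3, 5
}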
+ { + float tmp0 = (std::rand() % (max_value - min_value)) + min_value; + float tmp1 = (std::rand() % (max_value - min_value)) + min_value; + return ck::f4x2_pk_t{ck::type_convert(ck::float2_t{tmp0, tmp1})}; + } +}; + +template +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + T operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + return static_cast(min_value + tmp * (max_value - min_value)); + } +}; + +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::bhalf_t operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + float fp32_tmp = min_value + tmp * (max_value - min_value); + + return ck::type_convert(fp32_tmp); + } +}; + +#if defined CK_ENABLE_FP8 +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::f8_t operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + float fp32_tmp = min_value + tmp * (max_value - min_value); + + return ck::type_convert(fp32_tmp); + } +}; +#endif + +#if defined CK_ENABLE_BF8 +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::bf8_t operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + float fp32_tmp = min_value + tmp * (max_value - min_value); + + return ck::type_convert(fp32_tmp); + } +}; +#endif + +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::f4_t operator()(Is...) + { + float tmp = float(std::rand()) / float(RAND_MAX); + + float fp32_tmp = min_value + tmp * (max_value - min_value); + + return ck::type_convert(fp32_tmp); + } +}; + +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::f4x2_pk_t operator()(Is...) + { + float tmp0 = float(std::rand()) / float(RAND_MAX); + float tmp1 = float(std::rand()) / float(RAND_MAX); + + float fp32_tmp0 = min_value + tmp0 * (max_value - min_value); + float fp32_tmp1 = min_value + tmp1 * (max_value - min_value); + + return ck::f4x2_pk_t{ck::type_convert(ck::float2_t{fp32_tmp0, fp32_tmp1})}; + } +}; + +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::f6x32_pk_t operator()(Is...) + { + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + float rnd = float(std::rand()) / float(RAND_MAX); + float fp32 = min_value + rnd * (max_value - min_value); + r.pack(ck::type_convert(fp32), static_cast(i)); + }); + + return r; + } +}; + +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::bf6x32_pk_t operator()(Is...) + { + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + float rnd = float(std::rand()) / float(RAND_MAX); + float fp32 = min_value + rnd * (max_value - min_value); + r.pack(ck::type_convert(fp32), static_cast(i)); + }); + + return r; + } +}; + +template +struct GeneratorTensor_4 +{ + std::mt19937 generator; + std::normal_distribution distribution; + + GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) + : generator(seed), distribution(mean, stddev){}; + + template + T operator()(Is...) 
+ { + float tmp = distribution(generator); + + return ck::type_convert(tmp); + } +}; + +template <> +struct GeneratorTensor_4 +{ + std::mt19937 generator; + std::normal_distribution distribution; + + GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) + : generator(seed), distribution(mean, stddev){}; + + template + ck::f4x2_pk_t operator()(Is...) + { + float fp32_tmp0 = distribution(generator); + float fp32_tmp1 = distribution(generator); + + return ck::f4x2_pk_t{ck::type_convert(ck::float2_t{fp32_tmp0, fp32_tmp1})}; + } +}; + +template <> +struct GeneratorTensor_4 +{ + std::mt19937 generator; + std::normal_distribution distribution; + + GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) + : generator(seed), distribution(mean, stddev){}; + + template + ck::f6x32_pk_t operator()(Is...) + { + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + r.pack(ck::type_convert(distribution(generator)), + static_cast(i)); + }); + + return r; + } +}; + +template <> +struct GeneratorTensor_4 +{ + std::mt19937 generator; + std::normal_distribution distribution; + + GeneratorTensor_4(float mean, float stddev, unsigned int seed = 1) + : generator(seed), distribution(mean, stddev){}; + + template + ck::bf6x32_pk_t operator()(Is...) + { + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}([&](auto i) { + r.pack(ck::type_convert(distribution(generator)), + static_cast(i)); + }); + + return r; + } +}; + +struct GeneratorTensor_Checkboard +{ + template + float operator()(Ts... Xs) const + { + std::array dims = {static_cast(Xs)...}; + return std::accumulate(dims.begin(), + dims.end(), + true, + [](bool init, ck::index_t x) -> int { return init != (x % 2); }) + ? 1 + : -1; + } +}; + +/** + * @brief Is used to generate sequential values based on the specified dimension. + * + * @tparam T The type of the tensor values. + * @tparam Dim The specific dimension used for generation. + * + * GeneratorTensor_Sequential<1>{} will generate the following values for a 3x3 tensor: + * + * 0 1 2 + * 0 1 2 + * 0 1 2 + * + * Essentially, the values generated are logical coordinates of the generated element that + * correspond to dimension Dim. E.g. for 2-dimensional tensor and Dim=1, the values are the column + * indices. + * + */ +template +struct GeneratorTensor_Sequential +{ + template + T operator()(Ts... Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + + float tmp = dims[Dim]; + return ck::type_convert(tmp); + } +}; + +template +struct GeneratorTensor_Sequential +{ + template + ck::f4x2_pk_t operator()(Ts... Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + + float tmp = dims[Dim]; + return ck::type_convert(ck::float2_t(tmp)); + } +}; + +template +struct GeneratorTensor_Sequential +{ + template + ck::f6x32_pk_t operator()(Ts... Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + + float tmp = dims[Dim]; + + ck::f6x32_pk_t r; + ck::static_for<0, 32, 1>{}( + [&](auto i) { r.pack(ck::type_convert(tmp), static_cast(i)); }); + return r; + } +}; + +template +struct GeneratorTensor_Sequential +{ + template + ck::bf6x32_pk_t operator()(Ts... Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + + float tmp = dims[Dim]; + + ck::bf6x32_pk_t r; + ck::static_for<0, 32, 1>{}( + [&](auto i) { r.pack(ck::type_convert(tmp), static_cast(i)); }); + return r; + } +}; + +template +struct GeneratorTensor_Diagonal +{ + T value{1}; + + template + T operator()(Ts... 
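// Editorial sketch (not part of the CK sources): the +1/-1 pattern GeneratorTensor_Checkboard
// produces. The accumulate over coordinates reduces to the parity of their sum, so neighbouring
// elements alternate sign like a checkerboard; a 4x4 slice is printed below.
#include <iostream>

int main()
{
    for(int i = 0; i < 4; ++i)
    {
        for(int j = 0; j < 4; ++j)
        {
            int v = ((i + j) % 2 == 0) ? 1 : -1;
            std::cout << (v > 0 ? " 1 " : "-1 ");
        }
        std::cout << '\n';
    }
}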
Xs) const + { + std::array dims = {{static_cast(Xs)...}}; + size_t start_dim = dims.size() - NumEffectiveDim; + bool pred = true; + for(size_t i = start_dim + 1; i < dims.size(); i++) + { + pred &= (dims[start_dim] == dims[i]); + } + return pred ? value : T{0}; + } +}; diff --git a/csrc/rocm/ck_extensions/include/ck/library/utility/iterator.hpp b/csrc/rocm/ck_extensions/include/ck/library/utility/iterator.hpp new file mode 100755 index 000000000000..b44e2d8e3c4a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/library/utility/iterator.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/type.hpp" + +namespace ck { + +template +using iter_value_t = typename std::iterator_traits>::value_type; + +template +using iter_reference_t = decltype(*std::declval()); + +template +using iter_difference_t = typename std::iterator_traits>::difference_type; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/library/utility/literals.hpp b/csrc/rocm/ck_extensions/include/ck/library/utility/literals.hpp new file mode 100755 index 000000000000..a8bd6303f12f --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/library/utility/literals.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace literals { +// [P0330] Literal Suffix for (signed) size_t (C++23) +// ref: https://wg21.link/p0330r8 +inline constexpr std::size_t operator""_uz(unsigned long long size) +{ + return static_cast(size); +} + +inline constexpr std::size_t operator""_zu(unsigned long long size) +{ + return static_cast(size); +} +} // namespace literals +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/library/utility/numeric.hpp b/csrc/rocm/ck_extensions/include/ck/library/utility/numeric.hpp new file mode 100755 index 000000000000..9ee118d47571 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/library/utility/numeric.hpp @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +namespace ck { +template +auto accumulate_n(ForwardIterator first, Size count, T init, BinaryOperation op) + -> decltype(std::accumulate(first, std::next(first, count), init, op)) +{ + return std::accumulate(first, std::next(first, count), init, op); +} +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/library/utility/ranges.hpp b/csrc/rocm/ck_extensions/include/ck/library/utility/ranges.hpp new file mode 100755 index 000000000000..f11e4204a0ce --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/library/utility/ranges.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
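// Editorial sketch (not part of the CK sources): ck::accumulate_n above is a thin wrapper over
// std::accumulate that only folds the first `count` elements; with std::multiplies it is what
// ConvParam uses to multiply the leading spatial lengths together. Values are hypothetical.
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <vector>

int main()
{
    std::vector<long> spatial{3, 224, 224, 999}; // only the first 3 entries are folded
    long prod = std::accumulate(spatial.begin(), std::next(spatial.begin(), 3), 1L,
                                std::multiplies<>{});
    std::cout << prod << '\n'; // 3 * 224 * 224 = 150528
}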
+ +#pragma once + +#include +#include +#include + +#include "ck/library/utility/iterator.hpp" + +namespace ck { +namespace ranges { + +template +using iterator_t = decltype(std::begin(std::declval())); + +template +using sentinel_t = decltype(std::end(std::declval())); + +template +using range_size_t = decltype(std::size(std::declval())); + +template +using range_difference_t = ck::iter_difference_t>; + +template +using range_value_t = iter_value_t>; + +template +using range_reference_t = iter_reference_t>; + +template +struct is_range : std::false_type +{ +}; + +template +struct is_range< + T, + std::void_t())), decltype(std::end(std::declval()))>> + : std::true_type +{ +}; + +template +inline constexpr bool is_range_v = is_range::value; + +template +struct is_sized_range : std::false_type +{ +}; + +template +struct is_sized_range()))>> + : std::bool_constant> +{ +}; +} // namespace ranges +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/library/utility/thread.hpp b/csrc/rocm/ck_extensions/include/ck/library/utility/thread.hpp new file mode 100755 index 000000000000..483c58c46f69 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/library/utility/thread.hpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#ifdef __linux__ +#include +#endif +#include +namespace ck { +inline unsigned int get_available_cpu_cores() +{ +#if defined(__linux__) + cpu_set_t cpu_set; + if(sched_getaffinity(0, sizeof(cpu_set_t), &cpu_set) == 0) + { + unsigned int cpu_count = CPU_COUNT(&cpu_set); + if(cpu_count > 0) + return cpu_count; + } +#endif + // Fallback if sched_getaffinity unavailable or fails + return std::thread::hardware_concurrency(); +} +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp b/csrc/rocm/ck_extensions/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp new file mode 100755 index 000000000000..6b118e972e67 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/problem_transform/transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_FORWARD_CONVOLUTION3D_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: in +// B: wei +// C: out +// GemmM = N * Do * Ho * Wo +// GemmN = K +// GemmK = Z * Y * X * C +template +__host__ __device__ constexpr auto +transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad( + const TensorDescriptor& in_grid_desc_n_di_hi_wi_c, + const TensorDescriptor& wei_k_z_y_x_c_grid_desc, + const TensorDescriptor& out_n_do_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_grid_desc_n_di_hi_wi_c.GetLength(I0); + const auto K = out_n_do_ho_wo_k_grid_desc.GetLength(I4); + const auto C = in_grid_desc_n_di_hi_wi_c.GetLength(I4); + + const auto Di = in_grid_desc_n_di_hi_wi_c.GetLength(I1); + const auto Hi = in_grid_desc_n_di_hi_wi_c.GetLength(I2); + const auto Wi = in_grid_desc_n_di_hi_wi_c.GetLength(I3); + + const auto Do = out_n_do_ho_wo_k_grid_desc.GetLength(I1); + const auto Ho = out_n_do_ho_wo_k_grid_desc.GetLength(I2); + const auto Wo = out_n_do_ho_wo_k_grid_desc.GetLength(I3); + + const auto Z = wei_k_z_y_x_c_grid_desc.GetLength(I1); + const auto Y = wei_k_z_y_x_c_grid_desc.GetLength(I2); + const auto X = wei_k_z_y_x_c_grid_desc.GetLength(I3); + + const auto ConvStrideD = conv_strides[I0]; + const auto ConvStrideH = conv_strides[I1]; + const auto ConvStrideW = conv_strides[I2]; + + const auto ConvDilationD = conv_dilations[I0]; + const auto ConvDilationH = conv_dilations[I1]; + const auto ConvDilationW = conv_dilations[I2]; + + const auto InLeftPadD = in_left_pads[I0]; + const auto InLeftPadH = in_left_pads[I1]; + const auto InLeftPadW = in_left_pads[I2]; + + const auto InRightPadD = in_right_pads[I0]; + const auto InRightPadH = in_right_pads[I1]; + const auto InRightPadW = in_right_pads[I2]; + + const auto GemmM = N * Do * Ho * Wo; + const auto GemmN = K; + const auto GemmK = Z * Y * X * C; + const auto GemmK0 = GemmK / GemmK1; + + // A: input tensor + const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor( + in_grid_desc_n_di_hi_wi_c, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor( + in_grid_desc_n_dip_hip_wip_c, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1, 
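// Editorial sketch (not part of the CK sources): the implicit-GEMM sizes stated in the comments
// above (GemmM = N*Do*Ho*Wo, GemmN = K, GemmK = Z*Y*X*C, GemmK0 = GemmK / GemmK1), evaluated for a
// small, hypothetical 3D convolution problem.
#include <iostream>

int main()
{
    const long N = 2, K = 64, C = 32;
    const long Do = 8, Ho = 16, Wo = 16; // output spatial sizes
    const long Z = 3, Y = 3, X = 3;      // filter spatial sizes
    const long GemmK1 = 8;               // K1 split chosen by the kernel configuration

    const long GemmM  = N * Do * Ho * Wo; // 4096
    const long GemmN  = K;                // 64
    const long GemmK  = Z * Y * X * C;    // 864
    const long GemmK0 = GemmK / GemmK1;   // 108
    std::cout << GemmM << " x " << GemmN << " x (" << GemmK0 << " * " << GemmK1 << ")\n";
}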
2>{}, Sequence<3, 4>{}, Sequence<5, 6>{}, Sequence<7>{})); + + const auto in_grid_desc_gemmk_gemmm = + transform_tensor_descriptor(in_grid_desc_n_z_do_y_ho_x_wo_c, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_grid_desc_gemmk0_gemmm_gemmk1 = + transform_tensor_descriptor(in_grid_desc_gemmk_gemmm, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_grid_desc_gemmk_gemmn = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Z * Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Z * Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_grid_desc_gemmk0_gemmn_gemmk1 = + transform_tensor_descriptor(wei_grid_desc_gemmk_gemmn, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_grid_desc_gemmm_gemmn = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Do * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // const auto out_grid_desc_gemmm_gemmn = transform_tensor_descriptor( + // out_n_do_ho_wo_k_grid_desc, + // make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)), + // make_pass_through_transform(K)), + // make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<3>{}), + // make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_grid_desc_gemmk0_gemmm_gemmk1, + wei_grid_desc_gemmk0_gemmn_gemmk1, + out_grid_desc_gemmm_gemmn); +} + +} // namespace ck +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/stream_config.hpp b/csrc/rocm/ck_extensions/include/ck/stream_config.hpp new file mode 100755 index 000000000000..37ba250cf5d5 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/stream_config.hpp @@ -0,0 +1,19 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +struct StreamConfig +{ + hipStream_t stream_id_ = nullptr; + bool time_kernel_ = false; + int log_level_ = 0; + int cold_niters_ = 5; + int nrepeat_ = 50; + + bool flush_cache = false; + int rotating_count = 1; +}; diff --git a/csrc/rocm/ck_extensions/include/ck/tensor/static_tensor.hpp b/csrc/rocm/ck_extensions/include/ck/tensor/static_tensor.hpp new file mode 100755 index 000000000000..ef2bedd65cef --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor/static_tensor.hpp @@ -0,0 +1,273 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
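// Editorial usage sketch for the StreamConfig struct declared above: filling it in for a timed
// benchmark run. The include path follows the include directories added in CMakeLists.txt, and
// the file must be compiled with hipcc since the struct holds a hipStream_t; the numbers simply
// restate the header's defaults.
#include <hip/hip_runtime.h>
#include "ck/stream_config.hpp"

StreamConfig make_timed_config(hipStream_t stream)
{
    StreamConfig cfg{};
    cfg.stream_id_   = stream; // stream the kernels are launched on (nullptr = default stream)
    cfg.time_kernel_ = true;   // measure kernel time instead of only launching once
    cfg.cold_niters_ = 5;      // warm-up iterations
    cfg.nrepeat_     = 50;     // timed iterations
    return cfg;
}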
+ +#ifndef CK_STATIC_TENSOR_HPP +#define CK_STATIC_TENSOR_HPP + +namespace ck { + +// StaticTensor for Scalar +template ::type = false> +struct StaticTensor +{ + static constexpr auto desc_ = TensorDesc{}; + static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension(); + static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize(); + + __host__ __device__ constexpr StaticTensor() : invalid_element_scalar_value_{0} {} + + __host__ __device__ constexpr StaticTensor(T invalid_element_value) + : invalid_element_scalar_value_{invalid_element_value} + { + } + + // read access + template ::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr const T& operator[](Idx) const + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_[Number{}]; + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return zero_scalar_value_; + } + else + { + return invalid_element_scalar_value_; + } + } + } + + // write access + template ::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr T& operator()(Idx) + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_(Number{}); + } + else + { + return ignored_element_scalar_; + } + } + + StaticBuffer data_; + static constexpr T zero_scalar_value_ = T{0}; + const T invalid_element_scalar_value_; + T ignored_element_scalar_; +}; + +// StaticTensor for vector +template ::type = false> +struct StaticTensorTupleOfVectorBuffer +{ + static constexpr auto desc_ = TensorDesc{}; + static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension(); + static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize(); + + static constexpr index_t num_of_vector_ = + math::integer_divide_ceil(element_space_size_, ScalarPerVector); + + using V = vector_type; + + __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer() + : invalid_element_scalar_value_{0} + { + } + + __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value) + : invalid_element_scalar_value_{invalid_element_value} + { + } + + // Get S + // Idx is for S, not V + template ::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr const S& operator[](Idx) const + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_[Number{}]; + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return zero_scalar_value_; + } + else + { + return invalid_element_scalar_value_; + } + } + } + + // Set S + // Idx is for S, not V + template ::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr S& operator()(Idx) + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_(Number{}); + } + else + { + return ignored_element_scalar_; + } + } + + // Get X + // Idx 
is for S, not X. Idx should be aligned with X + template ::value || !is_native_type()) && + is_known_at_compile_time::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr X GetAsType(Idx) const + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + return data_.template GetAsType(Number{}); + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + // TODO: is this right way to initialize a vector? + return X{0}; + } + else + { + // TODO: is this right way to initialize a vector? + return X{invalid_element_scalar_value_}; + } + } + } + + // Set X + // Idx is for S, not X. Idx should be aligned with X + template ::value || !is_native_type()) && + is_known_at_compile_time::value && Idx::Size() == ndim_, + bool>::type = false> + __host__ __device__ constexpr void SetAsType(Idx, X x) + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord); + + if constexpr(is_valid) + { + data_.template SetAsType(Number{}, x); + } + } + + // Get read access to V. No is_valid check + // Idx is for S, not V. Idx should be aligned with V + template + __host__ __device__ constexpr const V& GetVectorTypeReference(Idx) const + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + return data_.GetVectorTypeReference(Number{}); + } + + // Get read access to V. No is_valid check + // Idx is for S, not V. Idx should be aligned with V + template + __host__ __device__ constexpr V& GetVectorTypeReference(Idx) + { + constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{})); + + constexpr index_t offset = coord.GetOffset(); + + return data_.GetVectorTypeReference(Number{}); + } + + StaticBufferTupleOfVector data_; + static constexpr S zero_scalar_value_ = S{0}; + const S invalid_element_scalar_value_ = S{0}; + S ignored_element_scalar_; +}; + +template ::type = false> +__host__ __device__ constexpr auto make_static_tensor(TensorDesc) +{ + return StaticTensor{}; +} + +template < + AddressSpaceEnum AddressSpace, + typename T, + typename TensorDesc, + typename X, + typename enable_if::type = false, + typename enable_if, remove_cvref_t>::value, bool>::type = false> +__host__ __device__ constexpr auto make_static_tensor(TensorDesc, X invalid_element_value) +{ + return StaticTensor{invalid_element_value}; +} + +} // namespace ck +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_description/cluster_descriptor.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_description/cluster_descriptor.hpp new file mode 100755 index 000000000000..2dfcad8e042e --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_description/cluster_descriptor.hpp @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
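The StaticTensor containers above implement coordinate-validity handling at compile time: a read at an invalid coordinate returns either numerical zero or a user-supplied sentinel (selected by InvalidElementUseNumericalZeroValue), and a write at an invalid coordinate is routed to an ignored scalar. The sketch below is a simplified standalone analogue of that behaviour with runtime bounds checks; TinyStaticTensor and its members are illustration-only names, not the ck:: types.

// Simplified analogue of the invalid-element handling in StaticTensor above.
#include <array>
#include <cassert>

template <typename T, int N, bool UseZeroForInvalid>
struct TinyStaticTensor
{
    std::array<T, N> data_{};
    T invalid_value_{}; // counterpart of invalid_element_scalar_value_
    T ignored_{};       // counterpart of ignored_element_scalar_

    const T& read(int i) const
    {
        static const T zero{}; // counterpart of zero_scalar_value_
        if(i >= 0 && i < N)
            return data_[i];
        if constexpr(UseZeroForInvalid)
            return zero;
        else
            return invalid_value_;
    }

    T& write(int i) { return (i >= 0 && i < N) ? data_[i] : ignored_; }
};

int main()
{
    TinyStaticTensor<float, 4, false> t;
    t.invalid_value_ = -1.f;
    t.write(2) = 5.f;          // valid coordinate
    t.write(9) = 7.f;          // out of range: goes to the ignored sink
    assert(t.read(2) == 5.f);
    assert(t.read(9) == -1.f); // out of range: sentinel comes back
    return 0;
}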
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template ::type> +__host__ __device__ constexpr auto make_cluster_descriptor( + const Lengths& lengths, + ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{}) +{ + constexpr index_t ndim_low = Lengths::Size(); + + const auto reordered_lengths = container_reorder_given_new2old(lengths, order); + + const auto low_lengths = generate_tuple( + [&](auto idim_low) { return reordered_lengths[idim_low]; }, Number{}); + + const auto transform = make_merge_transform(low_lengths); + + constexpr auto low_dim_old_top_ids = ArrangeOrder{}; + + constexpr auto up_dim_new_top_ids = Sequence<0>{}; + + return make_single_stage_tensor_adaptor( + make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids)); +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_description/multi_index_transform.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_description/multi_index_transform.hpp new file mode 100755 index 000000000000..c152cbfb1ea1 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_description/multi_index_transform.hpp @@ -0,0 +1,2046 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/multi_index.hpp" + +namespace ck { + +template +struct PassThrough +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{})); + + UpLengths up_lengths_; + + __host__ __device__ constexpr PassThrough() = default; + + __host__ __device__ constexpr PassThrough(const LowLength& low_length) + : up_lengths_{make_tuple(low_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}]; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("PassThrough, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct Pad +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{})); + + UpLengths up_lengths_; + LeftPadLength left_pad_length_; + RightPadLength right_pad_length_; + + __host__ __device__ constexpr Pad() = default; + + __host__ __device__ constexpr Pad(const LowLength& low_length, + const LeftPadLength& left_pad_length, + const RightPadLength& right_pad_length) + : up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)}, + left_pad_length_{left_pad_length}, + right_pad_length_{right_pad_length} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return SkipIsValidCheck; + } + + template + __host__ __device__ constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const + { + return SkipIsValidCheck || + ((idx_up[Number<0>{}] >= left_pad_length_) && + (idx_up[Number<0>{}] < up_lengths_[Number<0>{}] - right_pad_length_)); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Pad, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("left_pad_length %d", index_t{left_pad_length_}); + printf("right_pad_length %d", index_t{right_pad_length_}); + printf("}"); + } +}; + +template +struct LeftPad +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{})); + + UpLengths up_lengths_; + LeftPadLength left_pad_length_; + + __host__ __device__ constexpr LeftPad() = default; + + __host__ __device__ constexpr LeftPad(const LowLength& low_length, + const LeftPadLength& left_pad_length) + : up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] - left_pad_length_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return SkipIsValidCheck; + } + + template + __host__ __device__ constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const + { + return SkipIsValidCheck || (idx_up[Number<0>{}] >= left_pad_length_); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("LeftPad, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("left_pad_length_ %d", index_t{left_pad_length_}); + printf("}"); + } +}; + +template +struct RightPad +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{})); + + UpLengths up_lengths_; + LowLength low_length_; + RightPadLength right_pad_length_; + + __host__ __device__ constexpr RightPad() = default; + + __host__ __device__ constexpr RightPad(const LowLength& low_length, + const RightPadLength& right_pad_length) + : up_lengths_{make_tuple(low_length + right_pad_length)}, + low_length_{low_length}, + right_pad_length_{right_pad_length} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ static constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}]; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return SkipIsValidCheck; + } + + template + __host__ __device__ constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const + { + return SkipIsValidCheck || (idx_up[Number<0>{}] < low_length_); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("RightPad, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("low_length_ %d", index_t{low_length_}); + printf("left_pad_length_ %d", index_t{right_pad_length_}); + printf("}"); + } +}; + +// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1] +// UpLengths and Coefficients can be either of the followings: +// 1) Tuple of index_t, which is known at run-time, or +// 2) Tuple of Number, which is known at compile-time, or +// 3) Tuple of mixture of index_t and Number, which is known partially at run-time and partially +// at compile-time +template ::type = false> +struct Embed +{ + static constexpr index_t NDimUp = UpLengths::Size(); + + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex; + + UpLengths up_lengths_; + Coefficients coefficients_; + + __host__ __device__ constexpr Embed() = default; + + __host__ __device__ constexpr Embed(const UpLengths& up_lengths, + const Coefficients& coefficients) + : up_lengths_{up_lengths}, coefficients_{coefficients} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == NDimUp, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = 0; + + static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) { + idx_low(Number<0>{}) += idx_up[i] * this->coefficients_[i]; + }); + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) const + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp && + LowIdx::Size() == 1 && UpIdx::Size() == NDimUp, + "wrong! 
inconsistent # of dimension"); + + idx_diff_low(Number<0>{}) = 0; + + static_for<0, NDimUp, 1>{}( + [&](auto i) { idx_diff_low(Number<0>{}) += idx_diff_up[i] * coefficients_[i]; }); + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Embed, "); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("coefficients_ "); + print_multi_index(coefficients_); + printf("}"); + } +}; + +// Implementation of "Merge" transformation primitive that uses regular to do lowering of +// multi-index and use carry-and-borrow check to do lowering of multi-index delta +template +struct Merge_v1_carry_check +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = + decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Merge_v1_carry_check() = default; + + __host__ __device__ constexpr Merge_v1_carry_check(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + // normal division + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_low(i) = tmp / this->low_lengths_scan_[i]; + tmp -= idx_low[i] * this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex_1a(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& /* idx_up_new */, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions. + // However, + // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const + // can be calculated at compile-time. 
+ // 2) If idx_diff_up is not known at compile-time, but its value + // doesn't change during the whole kernel execution, then + // idx_diff_low_const also + // doesn't change during the whole kernel execution. Compiler generated + // ISA should + // only caclculate idx_diff_low_const once and save it durinng the whole + // kernel execution + // If neither 1) nor 2) is satisfied, then the calculation will also be + // computed at + // run-time each time this function is called, and can be very expensive. + LowerIndex idx_diff_low_const; + LowerIndex idx_low_length_minus_idx_diff_low_const; + LowerIndex idx_low_length_plus_idx_diff_low_const; + +#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = tmp / low_lengths_scan_[i]; + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = tmp; + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i]; + + idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i]; + }); +#else + // Hack: this force result into SGPR. Need to make sure the result is thread invariant + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]); + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = __builtin_amdgcn_readfirstlane(tmp); + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = + __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]); + + idx_low_length_plus_idx_diff_low_const(i) = + __builtin_amdgcn_readfirstlane(low_lengths_[i] + idx_diff_low_const[i]); + }); +#endif + + if constexpr(Hack == 1) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + else if constexpr(Hack == 2) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t borrow = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] - borrow; + + bool do_borrow = idx_low_tmp < -idx_diff_low_const[i]; + + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) -= borrow; + + borrow = do_borrow ? 1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow; + + idx_low += idx_diff_low; + } + else + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + bool do_borrow = idx_low_tmp < -idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? 
-idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 1 : 0; + carry = do_borrow ? -1 : carry; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + } + + template + __host__ __device__ void UpdateLowerIndex_1b(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& /* idx_up_new */, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions. + // However, + // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const + // can be calculated at compile-time. + // 2) If idx_diff_up is not known at compile-time, but its value + // doesn't change during the whole kernel execution, then + // idx_diff_low_const also + // doesn't change during the whole kernel execution. Compiler generated + // ISA should + // only caclculate idx_diff_low_const once and save it durinng the whole + // kernel execution + // If neither 1) nor 2) is satisfied, then the calculation will also be + // computed at + // run-time each time this function is called, and can be very expensive. + LowerIndex idx_diff_low_const; + LowerIndex idx_low_length_minus_idx_diff_low_const; + LowerIndex idx_low_length_plus_idx_diff_low_const; + +#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = tmp / low_lengths_scan_[i]; + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = tmp; + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = low_lengths_[i] - idx_diff_low_const[i]; + + idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i]; + }); +#else + // Hack: this force result into SGPR. Need to make sure the result is thread invariant + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]); + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = __builtin_amdgcn_readfirstlane(tmp); + + static_for<0, NDimLow, 1>{}([&](auto i) { + idx_low_length_minus_idx_diff_low_const(i) = + __builtin_amdgcn_readfirstlane(low_lengths_[i] - idx_diff_low_const[i]); + + idx_low_length_plus_idx_diff_low_const(i) = low_lengths_[i] + idx_diff_low_const[i]; + }); +#endif + + if constexpr(Hack == 1) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 
1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + else if constexpr(Hack == 2) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t borrow = 0; + + static_for{}([&](auto i) { + index_t negative_idx_low_tmp = borrow - idx_low[i]; + + bool do_borrow = negative_idx_low_tmp > idx_diff_low_const[i]; + + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low_const[i]; + + idx_diff_low(i) -= borrow; + + borrow = do_borrow ? 1 : 0; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] - borrow; + + idx_low += idx_diff_low; + } + else + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + index_t carry = 0; + + static_for{}([&](auto i) { + index_t idx_low_tmp = idx_low[i] + carry; + + bool do_carry = idx_low_tmp >= idx_low_length_minus_idx_diff_low_const[i]; + bool do_borrow = idx_low_tmp < -idx_diff_low_const[i]; + + idx_diff_low(i) = + do_carry ? -idx_low_length_minus_idx_diff_low_const[i] : idx_diff_low_const[i]; + idx_diff_low(i) = + do_borrow ? idx_low_length_plus_idx_diff_low_const[i] : idx_diff_low[i]; + + idx_diff_low(i) += carry; + + carry = do_carry ? 1 : 0; + carry = do_borrow ? -1 : carry; + }); + + idx_diff_low(Number<0>{}) = idx_diff_low_const[Number<0>{}] + carry; + + idx_low += idx_diff_low; + } + } + + template + __host__ __device__ void UpdateLowerIndex_2(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& /* idx_up_new */, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + // CalculateLowerIndex(idx_diff_low_const) has multiple integer divisions. + // However, + // 1) If idx_diff_up is known at compile-time, then idx_diff_low_const + // can be calculated at compile-time. + // 2) If idx_diff_up is not known at compile-time, but its value + // doesn't change during the whole kernel execution, then + // idx_diff_low_const also + // doesn't change during the whole kernel execution. Compiler generated + // ISA should + // only caclculate idx_diff_low_const once and save it durinng the whole + // kernel execution + // If neither 1) nor 2) is satisfied, then the calculation will also be + // computed at run-time each time this function is called, and can be + // very expensive. + LowerIndex idx_diff_low_const; + +#if !CK_HACK_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = tmp / low_lengths_scan_[i]; + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = tmp; +#else + // Hack: this force result into SGPR. 
Need to make sure the result is thread invariant + index_t tmp = idx_diff_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_diff_low_const(i) = __builtin_amdgcn_readfirstlane(tmp / low_lengths_scan_[i]); + tmp -= idx_diff_low_const[i] * low_lengths_scan_[i]; + }); + + idx_diff_low_const(Number{}) = __builtin_amdgcn_readfirstlane(tmp); +#endif + + if constexpr(Hack == 1) + { + // do carry check on each low dimension in reversed order + // do not need to check the first dimension + bool do_carry = 0; + + static_for{}([&](auto i) { + idx_diff_low(i) = idx_diff_low_const[i] + do_carry; + + index_t idx_low_tmp = idx_low[i] + idx_diff_low[i]; + + do_carry = idx_low_tmp >= low_lengths_[i]; + +#if 0 + // TODO: use exec-mask inline asm, which use 1 VALU + if(do_carry) + { + idx_diff_low(i) -= low_lengths_[i]; + } +#elif 1 + // this use 2 VALU + idx_diff_low(i) = do_carry ? idx_diff_low[i] - low_lengths_[i] : idx_diff_low[i]; +#elif 1 + // this use 2 VALU + index_t idx_diff_low_tmp = idx_diff_low[i] - low_lengths_[i]; + idx_diff_low(i) = do_carry ? idx_diff_low_tmp : idx_diff_low[i]; +#endif + + idx_low(i) += idx_diff_low[i]; + }); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_low_const[I0] + do_carry; + + idx_low(I0) += idx_diff_low[I0]; + } + else if constexpr(Hack == 2) + { + // do borrow check on each low dimension in reversed order + // do not need to check the first dimension + bool do_borrow = 0; + + static_for{}([&](auto i) { + idx_diff_low(i) = idx_diff_low_const[i] - do_borrow; + + index_t idx_low_tmp = idx_low[i] + idx_diff_low[i]; + + do_borrow = idx_low_tmp < 0; + +#if 0 + // TODO: use exec-mask inline asm + if(do_borrow) + { + idx_diff_low(i) += low_lengths_[i]; + } +#elif 1 + idx_diff_low(i) = do_borrow ? idx_diff_low[i] + low_lengths_[i] : idx_diff_low[i]; +#elif 1 + index_t idx_diff_low_tmp = idx_diff_low[i] + low_lengths_[i]; + idx_diff_low(i) = do_borrow ? 
idx_diff_low_tmp : idx_diff_low[i]; +#endif + + idx_low(i) += idx_diff_low[i]; + }); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_low_const[I0] - do_borrow; + + idx_low(I0) += idx_diff_low[I0]; + } + else + { + // not implemented + } + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { +#if 1 + UpdateLowerIndex_1a(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); +#elif 0 + UpdateLowerIndex_1b(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); +#else + UpdateLowerIndex_2(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); +#endif + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Merge_v1_carry_check, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan_ "); + print_multi_index(low_lengths_scan_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct lambda_merge_generate_MagicDivision_calculate_magic_multiplier +{ + template + __host__ __device__ constexpr auto operator()(Number i) const + { + return MagicDivision::CalculateMagicMultiplier(LowLengths{}[i]); + } +}; + +template +struct lambda_merge_generate_MagicDivision_calculate_magic_shift +{ + template + __host__ __device__ constexpr auto operator()(Number i) const + { + return MagicDivision::CalculateMagicShift(LowLengths{}[i]); + } +}; + +// Implementation of "Merge" transformation primitive that uses magic-number-division to do lowering +// of both multi-index and delta of multi-index +// Caution: +// 1. The magic number division implementation being used would produce correct result if the +// dividended is uint32_t and its value is with in 31-bit value range of uint32_t. +// 2. The magic number division for int32_t dividened has not been implemented, the int32_t +// dividend would be bit-wise interpreted as uint32_t and magic number division implementation for +// uint32_t is then used. +// 3. For Merge primitive, upper-index is the dividend. +// 4. When upper-index is uint32_t, its value need to be within 31-bit range. +// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be +// non-negative. 
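The "Caution" notes above describe the standard magic-number trick: an integer division by a fixed divisor is replaced by a multiply and a shift, and CK's in-kernel variant is only exact while the dividend stays within the 31-bit range. Below is a standalone sketch of the general idea; make_magic and magic_div are illustrative helpers, not CK's MagicDivision API, and the 128-bit product sidesteps the overflow concern that motivates the 31-bit restriction in a 64-bit implementation.

// Generic magic-number division sketch: replace n / d by a multiply and a shift.
#include <cassert>
#include <cstdint>

struct Magic { uint64_t multiplier; unsigned shift; };

// Choose l = ceil(log2(d)) and m = floor(2^(32+l) / d) + 1. Then
// (m * n) >> (32 + l) == n / d for all 32-bit n (Granlund-Montgomery style bound).
static Magic make_magic(uint32_t d)
{
    assert(d != 0);
    unsigned l = 0;
    while((uint64_t{1} << l) < d) { ++l; }
    const unsigned __int128 two_pow = (unsigned __int128)1 << (32 + l);
    return Magic{(uint64_t)(two_pow / d) + 1, 32 + l};
}

static uint32_t magic_div(uint32_t n, const Magic& m)
{
    return (uint32_t)(((unsigned __int128)m.multiplier * n) >> m.shift);
}

int main()
{
    const uint32_t d = 96; // e.g. a merged low length
    const Magic m  = make_magic(d);
    for(uint32_t n = 0; n < (1u << 20); ++n) // spot-check a range
        assert(magic_div(n, m) == n / d);
    assert(magic_div((1u << 31) - 1, m) == ((1u << 31) - 1) / d);
    return 0;
}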
+template +struct Merge_v2_magic_division +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); + + using LowLengthsMagicDivisorMultipiler = decltype(generate_tuple( + lambda_merge_generate_MagicDivision_calculate_magic_multiplier{}, + Number{})); + + using LowLengthsMagicDivisorShift = decltype(generate_tuple( + lambda_merge_generate_MagicDivision_calculate_magic_shift{}, + Number{})); + + LowLengths low_lengths_; + LowLengthsMagicDivisorMultipiler low_lengths_magic_divisor_multiplier_; + LowLengthsMagicDivisorShift low_lengths_magic_divisor_shift_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Merge_v2_magic_division() = default; + + __host__ __device__ constexpr Merge_v2_magic_division(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_magic_divisor_multiplier_{generate_tuple( + [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths[i]); }, + Number{})}, + low_lengths_magic_divisor_shift_{generate_tuple( + [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths[i]); }, + Number{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + static_for{}([&, this](auto i) { + index_t tmp2 = + MagicDivision::DoMagicDivision(tmp, + this->low_lengths_magic_divisor_multiplier_[i], + this->low_lengths_magic_divisor_shift_[i]); + idx_low(i) = tmp - tmp2 * this->low_lengths_[i]; + tmp = tmp2; + }); + + idx_low(Number<0>{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + index_t tmp = idx_up_new[Number<0>{}]; + + static_for{}([&, this](auto i) { + index_t tmp2 = + MagicDivision::DoMagicDivision(tmp, + this->low_lengths_magic_divisor_multiplier_[i], + this->low_lengths_magic_divisor_shift_[i]); + + index_t idx_low_old = idx_low[i]; + + idx_low(i) = tmp - tmp2 * this->low_lengths_[i]; + tmp = tmp2; + + idx_diff_low(i) = idx_low[i] - idx_low_old; + }); + + idx_diff_low(Number<0>{}) = tmp - idx_low(Number<0>{}); + + idx_low(Number<0>{}) = tmp; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Merge_v2_magic_division, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_magic_divisor_multiplier_ "); + print_multi_index(low_lengths_magic_divisor_multiplier_); + printf("low_lengths_magic_divisor_shift_ "); + print_multi_index(low_lengths_magic_divisor_shift_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +// Implementation of "Merge" transformation primitive that uses magic-number-division to do lowering +// of both multi-index and delta of multi-index +// Caution: +// 1. The magic number division implementation being used would produce correct result if the +// dividended is uint32_t and its value is with in 31-bit value range of uint32_t. +// 2. The magic number division for int32_t dividened has not been implemented, the int32_t +// dividend would be bit-wise interpreted as uint32_t and magic number division implementation for +// uint32_t is then used. +// 3. For Merge primitive, upper-index is the dividend. +// 4. When upper-index is uint32_t, its value need to be within 31-bit range. +// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be +// non-negative. 
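All of the Merge_* variants lower the same thing: one flat "upper" index is decomposed into a multi-index using the reverse exclusive scan of the low lengths. Merge_v1 above and Merge_v3 further down divide by that scan directly, Merge_v2 above peels off one low length at a time, and Merge_v2r2 below performs the scan-based division with magic numbers. The sketch below strips the lowering down to plain arithmetic; merge_lower is an illustrative helper, not a CK function.

// Flat-index -> multi-index lowering used by the Merge transforms, in plain C++.
#include <array>
#include <cassert>
#include <cstddef>

template <std::size_t N>
std::array<int, N> merge_lower(int idx_up, const std::array<int, N>& low_lengths)
{
    // low_lengths_scan_[i] = product of low_lengths[i+1 .. N-1]; scan[N-1] = 1
    std::array<int, N> scan{};
    scan[N - 1] = 1;
    for(int i = int(N) - 2; i >= 0; --i)
        scan[i] = scan[i + 1] * low_lengths[i + 1];

    std::array<int, N> idx_low{};
    int tmp = idx_up;
    for(std::size_t i = 0; i + 1 < N; ++i) // division + mod, as in Merge_v3
    {
        idx_low[i] = tmp / scan[i];
        tmp %= scan[i];
    }
    idx_low[N - 1] = tmp;
    return idx_low;
}

int main()
{
    // LowLengths {4, 3, 2} -> UpLength 24; flat index 17 = 2*6 + 2*2 + 1
    const auto idx = merge_lower(17, std::array<int, 3>{4, 3, 2});
    assert(idx[0] == 2 && idx[1] == 2 && idx[2] == 1);
    return 0;
}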
+template +struct Merge_v2r2_magic_division +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = + decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); + + using LowLengthsScanMagicDivisorMultipiler = decltype(generate_tuple( + lambda_merge_generate_MagicDivision_calculate_magic_multiplier{}, + Number{})); + + using LowLengthsScanMagicDivisorShift = decltype(generate_tuple( + lambda_merge_generate_MagicDivision_calculate_magic_shift{}, + Number{})); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + LowLengthsScanMagicDivisorMultipiler low_lengths_scan_magic_divisor_multiplier_; + LowLengthsScanMagicDivisorShift low_lengths_scan_magic_divisor_shift_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Merge_v2r2_magic_division() = default; + + __host__ __device__ constexpr Merge_v2r2_magic_division(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})}, + low_lengths_scan_magic_divisor_multiplier_{generate_tuple( + [&](auto i) { return MagicDivision::CalculateMagicMultiplier(low_lengths_scan_[i]); }, + Number{})}, + low_lengths_scan_magic_divisor_shift_{generate_tuple( + [&](auto i) { return MagicDivision::CalculateMagicShift(low_lengths_scan_[i]); }, + Number{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&, this](auto i) { + idx_low(i) = + MagicDivision::DoMagicDivision(tmp, + this->low_lengths_scan_magic_divisor_multiplier_[i], + this->low_lengths_scan_magic_divisor_shift_[i]); + + tmp -= idx_low[i] * this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + index_t tmp = idx_up_new[Number<0>{}]; + + static_for<0, NDimLow - 1, 1>{}([&, this](auto i) { + index_t idx_low_old = idx_low[i]; + + idx_low(i) = + MagicDivision::DoMagicDivision(tmp, + this->low_lengths_scan_magic_divisor_multiplier_[i], + this->low_lengths_scan_magic_divisor_shift_[i]); + + idx_diff_low(i) = idx_low[i] - idx_low_old; + + tmp -= idx_low[i] * this->low_lengths_scan_[i]; + }); + + idx_diff_low(Number{}) = tmp - idx_low[Number{}]; + + idx_low(Number{}) = tmp; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Merge_v2r2_magic_division, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan "); + print_multi_index(low_lengths_scan_); + printf("low_lengths_scan_magic_divisor_multiplier_ "); + print_multi_index(low_lengths_scan_magic_divisor_multiplier_); + printf("low_lengths_scan_magic_divisor_shift_ "); + print_multi_index(low_lengths_scan_magic_divisor_shift_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +// Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to +// be used for low_lengths that are known at compile time and are power of 2, otherwise performance +// will be very bad +template +struct Merge_v3_division_mod +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = + decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Merge_v3_division_mod() = default; + + __host__ __device__ constexpr Merge_v3_division_mod(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + // division and mod + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_low(i) = tmp / this->low_lengths_scan_[i]; + tmp %= this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + constexpr auto INm1 = Number{}; + + index_t tmp = idx_up_new[I0]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + const index_t tmp2 = idx_low[i]; + idx_low(i) = tmp / this->low_lengths_scan_[i]; + idx_diff_low(i) = idx_low[i] - tmp2; + tmp %= this->low_lengths_scan_[i]; + }); + + const index_t tmp2 = idx_low[INm1]; + idx_low(INm1) = tmp; + idx_diff_low(INm1) = idx_low[INm1] - tmp2; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Merge_v3_direct_division_mod, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan_ "); + print_multi_index(low_lengths_scan_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct UnMerge +{ + static constexpr index_t NDimUp = UpLengths::Size(); + + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex; + + using UpLengthsScan = + decltype(container_reverse_exclusive_scan(UpLengths{}, math::multiplies{}, Number<1>{})); + + UpLengths up_lengths_; + UpLengthsScan up_lengths_scan_; + + __host__ __device__ constexpr UnMerge() = default; + + __host__ __device__ constexpr UnMerge(const UpLengths& up_lengths) + : up_lengths_{up_lengths}, + up_lengths_scan_{ + container_reverse_exclusive_scan(up_lengths, math::multiplies{}, Number<1>{})} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return NDimUp; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + if constexpr(!Use24BitIntegerCalculation) + { + idx_low(Number<0>{}) = idx_up[Number{}]; + + static_for<0, NDimUp - 1, 1>{}( + [&](auto i) { idx_low(Number<0>{}) += idx_up[i] * up_lengths_scan_[i]; }); + } + else + { + idx_low(Number<0>{}) = idx_up[Number{}]; + + static_for<0, NDimUp - 1, 1>{}([&](auto i) { + idx_low(Number<0>{}) = + (0x00ffffff & idx_low[Number<0>{}]) + + (0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]); + }); + } + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) const + { + 
CalculateLowerIndex(idx_diff_low, idx_diff_up); + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("UnMerge, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("up_lengths_scan_"); + print_multi_index(up_lengths_scan_); + printf("}"); + } +}; + +template +struct Freeze +{ + LowerIndex low_idx_; + + __host__ __device__ constexpr Freeze() = default; + + __host__ __device__ constexpr Freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {} + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 0; } + + __host__ __device__ static constexpr auto GetUpperLengths() { return Tuple<>{}; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& /* idx_up */) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 0, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = low_idx_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& /* idx_diff_up */, + LowIdx& /* idx_low */, + const UpIdx& /* idx_up_new */, + Number) + { + idx_diff_low(Number<0>{}) = 0; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("Freeze"); + printf("low_idx_ %d", index_t{low_idx_}); + } +}; + +// Insert a dangling upper dimension without lower dimension +template +struct Insert +{ + using UpLengths = decltype(make_tuple(UpperLength{})); + + UpLengths up_lengths_; + + __host__ __device__ constexpr Insert() = default; + + __host__ __device__ constexpr Insert(const UpperLength& up_length) + : up_lengths_{make_tuple(up_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 0; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr auto GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx&, const UpIdx&) const + { + static_assert(LowIdx::Size() == 0 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + } + + template + __host__ __device__ static void + UpdateLowerIndex(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&, Number) + { + static_assert(LowIdxDiff::Size() == 0 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 0 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("Insert"); + print_multi_index(up_lengths_); + } +}; + +template +struct Vectorize +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(UpLength{})); + + UpLengths up_lengths_; + VectorSize vector_size_; + + __host__ __device__ constexpr Vectorize() = default; + + __host__ __device__ constexpr Vectorize(const VectorSize& vector_size, + const UpLength& up_length) + : vector_size_{vector_size}, up_lengths_{make_tuple(up_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = vector_size_ * idx_up[Number<0>{}]; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) const + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = vector_size_ * idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Vectorize, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct Slice +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + + using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{})); + + UpLengths up_lengths_; + SliceBegin slice_begin_; + SliceEnd slice_end_; + + __host__ __device__ constexpr Slice() = default; + + __host__ __device__ constexpr Slice(const LowLength&, + const SliceBegin& slice_begin, + const SliceEnd& slice_end) + : up_lengths_{make_tuple(slice_end - slice_begin)}, + slice_begin_{slice_begin}, + slice_end_{slice_end} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] + slice_begin_; + } + + template + __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx&, + Number) + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + idx_diff_low(I0) = idx_diff_up[I0]; + + idx_low += idx_diff_low; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return true; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx&) const + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Slice, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("slice_begin_ %d", index_t{slice_begin_}); + printf("slice_end %d", index_t{slice_end_}); + printf("}"); + } +}; + +/* + * \brief lower_idx = upper_idx % modulus. + * TODO: Need an improved implementation since the modulo operation is expensive. 
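 *
 * Illustrative example (not exercised elsewhere in this header): with modulus_ = 4, an
 * upper index of 7 maps to lower index 7 % 4 = 3. UpdateLowerIndex cannot simply forward
 * the upper-index diff, because the wrap-around point depends on the current position, so
 * it recomputes (up_idx + idx_diff_up) % modulus_ and reports the resulting change in
 * idx_diff_low.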
+ */ +template +struct Modulo +{ + using LowerIndex = MultiIndex<1>; + using UpperIndex = MultiIndex<1>; + using UpLengths = decltype(make_tuple(UpLength{})); + + Modulus modulus_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Modulo() = default; + + __host__ __device__ constexpr Modulo(const Modulus& modulus, const UpLength& up_length) + : modulus_{modulus}, up_lengths_{make_tuple(up_length)} + { + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 1; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 1 && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}] % modulus_; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_diff_up, + LowIdx& idx_low, + const UpIdx& up_idx, + Number) const + { + static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 && + UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + + const auto idx_low_old = idx_low; + idx_low(I0) = (up_idx(I0) + idx_diff_up(I0)) % modulus_; + idx_diff_low(I0) = idx_low - idx_low_old; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Modulus, "); + printf("up_lengths_"); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +struct Xor +{ + using LowerIndex = MultiIndex<2>; + using UpperIndex = MultiIndex<2>; + + using UpLengths = LowLengths; + + UpLengths up_lengths_; + + __host__ __device__ constexpr Xor() : up_lengths_{} {} + + __host__ __device__ constexpr Xor(const LowLengths& low_lengths) : up_lengths_{low_lengths} {} + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return 2; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 2; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == 2 && UpIdx::Size() == 2, + "wrong! inconsistent # of dimension"); + + idx_low(Number<0>{}) = idx_up[Number<0>{}]; + + if constexpr(ApplyModulo) + { + idx_low(Number<1>{}) = + idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]); + } + else + { + idx_low(Number<1>{}) = idx_up[Number<1>{}] ^ idx_up[Number<0>{}]; + } + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up, + Number) const + { + static_assert(LowIdxDiff::Size() == 2 && UpIdxDiff::Size() == 2 && LowIdx::Size() == 2 && + UpIdx::Size() == 2, + "wrong! 
inconsistent # of dimension"); + + const auto idx_low_old = idx_low; + + CalculateLowerIndex(idx_low, idx_up); + + idx_diff_low = idx_low - idx_low_old; + } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("Xor{"); + + // + printf("up_lengths_: "); + print(up_lengths_); + printf(", "); + + printf("}"); + } +}; +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_description/multi_index_transform_helper.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_description/multi_index_transform_helper.hpp new file mode 100755 index 000000000000..8feadf63c604 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_description/multi_index_transform_helper.hpp @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform.hpp" + +namespace ck { + +template +__host__ __device__ constexpr auto make_pass_through_transform(const LowLength& low_length) +{ + return PassThrough{low_length}; +} + +template +__host__ __device__ constexpr auto +make_pad_transform(const LowLength& low_length, + const LeftPad& left_pad, + const RightPad& right_pad, + integral_constant = integral_constant{}) +{ + return Pad{low_length, left_pad, right_pad}; +} + +template +__host__ __device__ constexpr auto make_left_pad_transform( + const LowLength& low_length, + const LeftPadLength& left_pad, + integral_constant = integral_constant{}) +{ + return LeftPad{low_length, left_pad}; +} + +template +__host__ __device__ constexpr auto make_right_pad_transform( + const LowLength& low_length, + const RightPadLength& right_pad, + integral_constant = integral_constant{}) +{ + return RightPad{low_length, right_pad}; +} + +template ::type = false> +__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths, + const Coefficients& coefficients) +{ + return Embed{up_lengths, coefficients}; +} + +template +__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths) +{ +#if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION + return make_merge_transform_v2_magic_division(low_lengths); +#else + return make_merge_transform_v1_carry_check(low_lengths); +#endif +} + +template +__host__ __device__ constexpr auto +make_merge_transform_v1_carry_check(const LowLengths& low_lengths) +{ + return Merge_v1_carry_check{low_lengths}; +} + +template +__host__ __device__ constexpr auto +make_merge_transform_v2_magic_division(const LowLengths& low_lengths) +{ +#if 1 + return Merge_v2_magic_division{low_lengths}; +#else + return Merge_v2r2_magic_division{low_lengths}; +#endif +} + +template +__host__ __device__ constexpr auto +make_merge_transform_v3_division_mod(const LowLengths& low_lengths) +{ + return Merge_v3_division_mod{low_lengths}; +} + +template +__host__ __device__ constexpr auto make_unmerge_transform( + const UpLengths& up_lengths, + integral_constant = integral_constant{}) +{ + return UnMerge{up_lengths}; +} + +template +__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_idx) +{ + return 
Freeze{low_idx}; +} + +template +__host__ __device__ constexpr auto make_insert_transform(const UpperIndex& up_idx) +{ + return Insert{up_idx}; +} + +template +__host__ __device__ constexpr auto make_slice_transform(const LowLength& low_length, + const SliceBegin& slice_begin, + const SliceEnd& slice_end) +{ + return Slice{low_length, slice_begin, slice_end}; +} + +template +__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size, + const UpLength& up_length) +{ + return Vectorize{vector_size, up_length}; +} + +template +__host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus, + const UpLength& up_length) +{ + return Modulo{modulus, up_length}; +} + +template +__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths& low_lengths) +{ + return Xor{low_lengths}; +} + +template +__host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_lengths) +{ + return Xor{low_lengths}; +} +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_description/tensor_adaptor.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_description/tensor_adaptor.hpp new file mode 100755 index 000000000000..3ffac32469a8 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_description/tensor_adaptor.hpp @@ -0,0 +1,482 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +namespace ck { + +// Transforms: Tuple +// LowerDimensionHiddenIdss : Tuple, ...> +// UpperDimensionHiddenIdss : Tuple, ...> +// BottomDimensionHiddenIds : Sequence<...> +// TopDimensionHiddenIds : Sequence<...> +template +struct TensorAdaptor +{ + __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); } + + __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; } + + __host__ __device__ static constexpr auto GetLowerDimensionHiddenIdss() + { + return LowerDimensionHiddenIdss{}; + } + + __host__ __device__ static constexpr auto GetUpperDimensionHiddenIdss() + { + return UpperDimensionHiddenIdss{}; + } + + __host__ __device__ static constexpr auto GetTopDimensionHiddenIds() + { + return TopDimensionHiddenIds{}; + } + + __host__ __device__ static constexpr auto GetBottomDimensionHiddenIds() + { + return BottomDimensionHiddenIds{}; + } + + __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) + { + const auto lengths = generate_tuple( + [&](auto idim_top) { + constexpr auto tmp = GetTransformAndItsUpperDimension(idim_top); + + constexpr index_t itran = tmp[Number<0>{}]; + constexpr index_t idim_up = tmp[Number<1>{}]; + constexpr bool found = tmp[Number<2>{}]; + + static_assert(found == true, + "wrong! 
not found matching transformation and upper-dimension"); + + const auto length = + transforms[Number{}].GetUpperLengths()[Number{}]; + + return length; + }, + Number{}); + + // TODO: make container_reduce support tuple of Number and index_t + return container_reduce(lengths, math::multiplies{}, Number<1>{}); + } + + template + __host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number) + { + constexpr auto idim_top = Number{}; + + constexpr index_t idim_hidden = TopDimensionHiddenIds::At(idim_top); + + index_t itran_found = 0; + index_t idim_up_found = 0; + bool found = false; + + static_for<0, ntransform_, 1>{}([&](auto itran) { + constexpr auto up_dim_ids = UpperDimensionHiddenIdss{}[itran]; + + static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) { + if constexpr(up_dim_ids[idim_up] == idim_hidden) + { + itran_found = itran; + idim_up_found = idim_up; + found = true; + } + }); + }); + + return make_tuple(itran_found, idim_up_found, found); + } + + __host__ __device__ static constexpr index_t GetNumOfBottomDimension() + { + return BottomDimensionHiddenIds::Size(); + } + + __host__ __device__ static constexpr index_t GetNumOfTopDimension() + { + return TopDimensionHiddenIds::Size(); + } + + __host__ __device__ static constexpr index_t GetNumOfHiddenDimension() + { + constexpr auto all_low_dim_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, + LowerDimensionHiddenIdss{}); + + constexpr auto all_up_dim_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, + UpperDimensionHiddenIdss{}); + + constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); + + using unique_sort_all_dim_ids = typename sequence_unique_sort, + math::equal>::type; + + return unique_sort_all_dim_ids::Size(); + } + + constexpr static index_t ntransform_ = GetNumOfTransform(); + constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension(); + constexpr static index_t ndim_bottom_ = GetNumOfBottomDimension(); + constexpr static index_t ndim_top_ = GetNumOfTopDimension(); + + using HiddenIndex = MultiIndex; + using BottomIndex = MultiIndex; + using TopIndex = MultiIndex; + + // may be index_t or Number<> + using ElementSize = remove_cv_t; + + public: +#if 0 // workaround compiler complaint about constexpr + __host__ __device__ constexpr TensorAdaptor() = default; +#else + __host__ __device__ constexpr TensorAdaptor() : transforms_{}, element_size_{} {} +#endif + + __host__ __device__ constexpr TensorAdaptor(const Transforms& transforms) + : transforms_{transforms}, element_size_{InitializeElementSize(transforms)} + { + static_assert(Transforms::Size() == ntransform_ && + LowerDimensionHiddenIdss::Size() == ntransform_ && + UpperDimensionHiddenIdss::Size() == ntransform_, + "wrong! inconsistent # of transformations"); + + // TODO check dependency of dimensions is valid + } + + __host__ __device__ constexpr auto GetElementSize() const { return element_size_; } + +#if 0 // debug + template + __host__ __device__ constexpr index_t GetTopDimensionLength(Number idim) const + { + // TODO: not implemented + } + + template + __host__ __device__ constexpr index_t GetBottomDimensionLength(Number idim) const + { + // TODO: not implemented + } +#endif + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + static_assert(TopIdx::Size() == TopDimensionHiddenIds::Size(), + "wrong! 
# of dimension inconsistent"); + + constexpr index_t ntransform = GetNumOfTransform(); + constexpr index_t ndim_hidden = GetNumOfHiddenDimension(); + + MultiIndex idx_hidden; + + // initialize uppest index + set_container_subset(idx_hidden, GetTopDimensionHiddenIds(), idx_top); + + // calculate hidden index + static_for{}([&](auto itran_p1) { + auto itran = itran_p1 - Number<1>{}; + const auto& tran = GetTransforms().At(itran); + constexpr auto dims_low = GetLowerDimensionHiddenIdss().At(itran); + constexpr auto dims_up = GetUpperDimensionHiddenIdss().At(itran); + + const auto idx_up = get_container_subset(idx_hidden, dims_up); + + MultiIndex idx_low; + + tran.CalculateLowerIndex(idx_low, idx_up); + + set_container_subset(idx_hidden, dims_low, idx_low); + }); + + return get_container_subset(idx_hidden, BottomDimensionHiddenIds{}); + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + bool is_known = true; + + static_for<0, Transforms::Size(), 1>{}([&](auto i) { + is_known &= remove_cvref_t::IsKnownAtCompileTime(); + }); + + return is_known && is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("TensorAdaptor, "); + static_for<0, ntransform_, 1>{}([&](auto i) { + printf("transforms: "); + transforms_[i].Print(); + printf("LowerDimensionHiddenIds:"); + LowerDimensionHiddenIdss{}.At(i).Print(); + printf("UpperDimensionHiddenIds:"); + UpperDimensionHiddenIdss{}.At(i).Print(); + }); + + printf("BottomDimensionHiddenIds:"); + BottomDimensionHiddenIds::Print(); + printf("TopDimensionHiddenIds:"); + TopDimensionHiddenIds::Print(); + + printf("}"); + } + + private: + Transforms transforms_; + ElementSize element_size_; +}; + +template +__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& adaptor0, + const TensorAdaptor1& adaptor1) +{ + static_assert(TensorAdaptor0::GetNumOfTopDimension() == + TensorAdaptor1::GetNumOfBottomDimension(), + "wrong!"); + + // all_transforms = transform0 + transform1 + const auto all_transforms = + container_concat(adaptor0.GetTransforms(), adaptor1.GetTransforms()); + + // shift + constexpr index_t adaptor0_max_hidden_id = [&]() { + index_t adaptor0_max_hidden_id_ = NumericLimits::Min(); + + static_for<0, TensorAdaptor0::GetNumOfTransform(), 1>{}([&](auto itran) { + constexpr index_t ndim_low = + TensorAdaptor0{}.GetTransforms()[itran].GetNumOfLowerDimension(); + + static_for<0, ndim_low, 1>{}([&](auto idim_low) { + adaptor0_max_hidden_id_ = + math::max(adaptor0_max_hidden_id_, + TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low].value); + }); + + constexpr index_t ndim_up = + TensorAdaptor0{}.GetTransforms()[itran].GetNumOfUpperDimension(); + + static_for<0, ndim_up, 1>{}([&](auto idim_up) { + adaptor0_max_hidden_id_ = + math::max(adaptor0_max_hidden_id_, + TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up].value); + }); + }); + + return adaptor0_max_hidden_id_; + }(); + + constexpr index_t adaptor1_min_hidden_id = [&]() { + index_t adaptor1_min_hidden_id_ = NumericLimits::Max(); + + static_for<0, TensorAdaptor1::GetNumOfTransform(), 1>{}([&](auto itran) { + constexpr index_t ndim_low = + TensorAdaptor1{}.GetTransforms()[itran].GetNumOfLowerDimension(); + + // get the min of all lower dimenions, but not bottom dimension (because their id will + // be matched with top id from adaptor0) + static_for<0, ndim_low, 1>{}([&](auto idim_low) { + constexpr index_t low_dim_hidden_id = + 
TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran][idim_low].value; + + bool is_bottom_dim = false; + static_for<0, TensorAdaptor1::GetNumOfBottomDimension(), 1>{}([&](auto i) { + if constexpr(low_dim_hidden_id == + TensorAdaptor1::GetBottomDimensionHiddenIds()[i]) + { + is_bottom_dim = true; + } + }); + + if(!is_bottom_dim) + { + adaptor1_min_hidden_id_ = math::min(adaptor1_min_hidden_id_, low_dim_hidden_id); + } + }); + + constexpr index_t ndim_up = + TensorAdaptor1{}.GetTransforms()[itran].GetNumOfUpperDimension(); + + // get the min of all upper dimensions + static_for<0, ndim_up, 1>{}([&](auto idim_up) { + adaptor1_min_hidden_id_ = + math::min(adaptor1_min_hidden_id_, + TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran][idim_up].value); + }); + }); + + return adaptor1_min_hidden_id_; + }(); + + constexpr index_t adaptor1_hidden_id_shift = + adaptor0_max_hidden_id + 1 - adaptor1_min_hidden_id; + + constexpr index_t ndim_bottom_1 = TensorAdaptor1::GetNumOfBottomDimension(); + + // all_low_dim_hidden_idss = + // low_dim_hidden_idss_0 + match_hidden_id_for_1(shift_hidden_id_for_1(low_dim_hiden_idss_1)) + constexpr auto low_dim_hidden_idss_1 = generate_tuple( + // generate sequence of ids for a transform + [&](auto itran) { + constexpr auto ndim_low_1 = TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran].Size(); + + constexpr auto low_dim_hidden_ids_1 = + TensorAdaptor1::GetLowerDimensionHiddenIdss()[itran]; + + // sequence in, sequence out + constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr + { + auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1); + + // shift hidden id so every dim id is unique + static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { + low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift; + }); + + // match hidden id + static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) { + static_for<0, ndim_bottom_1, 1>{}([&](auto idim_bottom_1) { + // if this low dim is bottom dim, then do id matching + if constexpr(low_dim_hidden_ids_1[idim_low_1] == + TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1]) + { + low_dim_hidden_ids_1_mod_(idim_low_1) = + TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1]; + } + }); + }); + + return low_dim_hidden_ids_1_mod_; + } + (); + + return generate_sequence_v2( + [&](auto i) constexpr { return Number{}; }, + Number{}); + }, + Number{}); + + constexpr auto all_low_dim_hidden_idss = + container_concat(TensorAdaptor0::GetLowerDimensionHiddenIdss(), low_dim_hidden_idss_1); + + // all_up_dim_hidden_idss = + // up_dim_hidden_idss_0 + shift_hidden_id_for_1(up_dim_hiden_idss_1) + constexpr auto up_dim_hidden_idss_1 = generate_tuple( + // generate sequence of ids for a transform + [&](auto itran) { + constexpr auto ndim_up_1 = TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran].Size(); + + constexpr auto up_dim_hidden_ids_1 = + TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran]; + + // sequence in, constexpr tuple out + constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr + { + auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1); + + // shift hidden id + static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) { + up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift; + }); + + return up_dim_hidden_ids_1_mod_; + } + (); + + // constexpr tuple to sequence + return generate_sequence_v2( + [&](auto i) constexpr { return Number{}; }, + Number{}); + }, + Number{}); + + constexpr auto all_up_dim_hidden_idss = + container_concat(TensorAdaptor0::GetUpperDimensionHiddenIdss(), 
up_dim_hidden_idss_1); + + // bottom_dim_hidden_ids = bottom_dim_hidden_ids_0 + constexpr auto bottom_dim_hidden_ids = TensorAdaptor0::GetBottomDimensionHiddenIds(); + + // top_dim_hidden_ids = shift_hidden_id(top_dim_hidden_ids_1) + constexpr auto top_dim_hidden_ids = + TensorAdaptor1::GetTopDimensionHiddenIds() + Number{}; + + // put everything together + return TensorAdaptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{all_transforms}; +} + +// Transforms: Tuple +// LowerDimensionOldTopIdss: Tuple, ...> +// UpperDimensionNewTopIdss: Tuple, ...> +template +__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms& transforms, + LowerDimensionOldTopIdss, + UpperDimensionNewTopIdss) +{ + constexpr index_t ntransform = Transforms::Size(); + + static_assert(LowerDimensionOldTopIdss::Size() == ntransform && + UpperDimensionNewTopIdss::Size() == ntransform, + "wrong!"); + + // sanity check on LowerDimensionOldTopIdss and UpperDimensionNewTopIdss + constexpr auto all_low_dim_old_top_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionOldTopIdss{}); + + constexpr auto all_up_dim_new_top_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionNewTopIdss{}); + + static_assert(is_valid_sequence_map::value && + is_valid_sequence_map::value, + "wrong!"); + + constexpr index_t ndim_old_top = all_low_dim_old_top_ids.Size(); + constexpr index_t ndim_new_top = all_up_dim_new_top_ids.Size(); + + // low_dim_hidden_idss + constexpr auto low_dim_hidden_idss = LowerDimensionOldTopIdss{}; + + // up_dim_hidden_idss: shift UpperDimensionNewTopIdss by ndim_bottom + constexpr auto up_dim_hidden_idss = generate_tuple( + [](auto itran) { return UpperDimensionNewTopIdss{}[itran] + Number{}; }, + Number{}); + + // bottom_dim_hidden_ids + constexpr auto bottom_dim_hidden_ids = + typename arithmetic_sequence_gen<0, ndim_old_top, 1>::type{}; + + // top_dim_hidden_ids + constexpr auto top_dim_hidden_ids = + typename arithmetic_sequence_gen<0, ndim_new_top, 1>::type{} + Number{}; + + return TensorAdaptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{transforms}; +} + +template = 2, bool>::type = false> +__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs) +{ + return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...)); +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_description/tensor_descriptor.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_description/tensor_descriptor.hpp new file mode 100755 index 000000000000..f1df2eedd466 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_description/tensor_descriptor.hpp @@ -0,0 +1,615 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/tensor_description/multi_index_transform.hpp" + +namespace ck { + +template +struct TensorCoordinate; + +template +struct TensorCoordinateStep; + +// Transforms: Tuple +// LowerDimensionIdss : Tuple, ...> +// UpperDimensionIdss : Tuple, ...> +// VisibleDimensionIds> : Sequence<...> +template +struct TensorDescriptor +{ + // TODO make these private + __host__ __device__ static constexpr index_t GetNumOfTransform() { return Transforms::Size(); } + + __host__ __device__ static constexpr index_t GetNumOfVisibleDimension() + { + return VisibleDimensionIds::Size(); + } + + __host__ __device__ static constexpr index_t GetNumOfHiddenDimension() + { + constexpr auto all_low_dim_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{}); + + constexpr auto all_up_dim_ids = unpack( + [](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{}); + + constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids); + + using unique_sort_all_dim_ids = typename sequence_unique_sort, + math::equal>::type; + + return unique_sort_all_dim_ids::Size(); + } + + __host__ __device__ static constexpr auto InitializeElementSize(const Transforms& transforms) + { + const auto lengths = generate_tuple( + [&](auto idim_visible) { + constexpr auto tmp = GetTransformAndItsUpperDimension(idim_visible); + + constexpr index_t itran = tmp[Number<0>{}]; + constexpr index_t idim_up = tmp[Number<1>{}]; + constexpr bool found = tmp[Number<2>{}]; + + static_assert(found == true, + "wrong! not found matching transformation and upper-dimension"); + + const auto length = + transforms[Number{}].GetUpperLengths()[Number{}]; + + return length; + }, + Number{}); + + // TODO: make container_reduce support tuple of Number and index_t + return container_reduce(lengths, math::multiplies{}, Number<1>{}); + } + + template + __host__ __device__ static constexpr auto GetTransformAndItsUpperDimension(Number) + { + constexpr auto idim_visible = Number{}; + + constexpr index_t idim_hidden = VisibleDimensionIds::At(idim_visible); + + index_t itran_found = 0; + index_t idim_up_found = 0; + bool found = false; + + static_for<0, ntransform_, 1>{}([&](auto itran) { + constexpr auto up_dim_ids = UpperDimensionIdss{}[itran]; + + static_for<0, up_dim_ids.Size(), 1>{}([&](auto idim_up) { + if constexpr(up_dim_ids[idim_up] == idim_hidden) + { + itran_found = itran; + idim_up_found = idim_up; + found = true; + } + }); + }); + + return make_tuple(itran_found, idim_up_found, found); + } + + constexpr static index_t ntransform_ = GetNumOfTransform(); + constexpr static index_t ndim_visible_ = GetNumOfVisibleDimension(); + constexpr static index_t ndim_hidden_ = GetNumOfHiddenDimension(); + + using VisibleIndex = MultiIndex; + using HiddenIndex = MultiIndex; + using Coordinate = TensorCoordinate; + + // may be index_t or Number<> + using ElementSize = remove_cv_t; + + public: +#if 0 // workaround compiler complaint about constexpr + __host__ __device__ constexpr TensorDescriptor() = default; +#else + __host__ __device__ constexpr TensorDescriptor() + : transforms_{}, element_size_{}, element_space_size_{} + { + } +#endif + + __host__ __device__ constexpr TensorDescriptor(const Transforms& transforms, + ElementSpaceSize element_space_size) + : transforms_{transforms}, + element_size_{InitializeElementSize(transforms)}, + element_space_size_{element_space_size} + + { + 
static_assert(Transforms::Size() == ntransform_ && + LowerDimensionIdss::Size() == ntransform_ && + UpperDimensionIdss::Size() == ntransform_, + "wrong! inconsistent # of transformations"); + + // TODO check dependency of dimensions is valid + } + + __host__ __device__ static constexpr index_t GetNumOfDimension() + { + return GetNumOfVisibleDimension(); + } + + template + __host__ __device__ constexpr auto GetLength(Number) const + { + static_assert(IDim >= 0 && IDim < ndim_visible_, "wrong! out of range"); + + constexpr auto tmp = GetTransformAndItsUpperDimension(Number{}); + + constexpr index_t itran = tmp[Number<0>{}]; + constexpr index_t idim_up = tmp[Number<1>{}]; + constexpr bool found = tmp[Number<2>{}]; + + static_assert(found == true, + "wrong! not found matching transformation and upper-dimension"); + + return transforms_[Number{}].GetUpperLengths()[Number{}]; + } + + __host__ __device__ constexpr auto GetLengths() const + { + // FIXME: use Tuple of reference instead + return generate_sequence_v2([&](auto I) { return GetLength(I); }, Number{}); + } + + __host__ __device__ constexpr auto GetElementSize() const { return element_size_; } + + __host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; } + + template + __host__ __device__ constexpr index_t CalculateOffset(const Idx& idx) const + { + static_assert(Idx::Size() == GetNumOfDimension(), "wrong! inconsistent # of dimension"); + + return make_tensor_coordinate(*this, idx).GetOffset(); + } + + // TODO make these private + __host__ __device__ constexpr const auto& GetTransforms() const { return transforms_; } + + __host__ __device__ static constexpr auto GetLowerDimensionIdss() + { + return LowerDimensionIdss{}; + } + + __host__ __device__ static constexpr auto GetUpperDimensionIdss() + { + return UpperDimensionIdss{}; + } + + __host__ __device__ static constexpr auto GetVisibleDimensionIds() + { + return VisibleDimensionIds{}; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + bool is_known = true; + + static_for<0, Transforms::Size(), 1>{}([&](auto i) { + is_known &= remove_cvref_t::IsKnownAtCompileTime(); + }); + + return is_known && is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("TensorDescriptor, "); + static_for<0, ntransform_, 1>{}([&](auto i) { + printf("transforms: "); + transforms_[i].Print(); + printf("LowerDimensionIds:"); + LowerDimensionIdss{}.At(i).Print(); + printf("UpperDimensionIds:"); + UpperDimensionIdss{}.At(i).Print(); + }); + printf("}"); + + VisibleDimensionIds::Print(); + } + + // TODO make these private + Transforms transforms_; + ElementSize element_size_; + ElementSpaceSize element_space_size_; +}; + +template +struct TensorCoordinate +{ + // TODO make these private + static constexpr index_t ndim_visible_ = VisibleDimensionIds::Size(); + + using HiddenIndex = MultiIndex; + using VisibleIndex = MultiIndex; + + public: + __host__ __device__ constexpr TensorCoordinate() = default; + + __host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden) + : idx_hidden_{idx_hidden} + { + } + + __host__ __device__ constexpr auto GetIndex() const { return GetVisibleIndex(); } + + __host__ __device__ constexpr index_t GetOffset() const { return idx_hidden_[Number<0>{}]; } + + // TODO make these private + __host__ __device__ constexpr const auto& GetHiddenIndex() const { return idx_hidden_; } + + __host__ __device__ auto& GetHiddenIndex() { 
return idx_hidden_; } + + __host__ __device__ constexpr auto GetVisibleIndex() const + { + return get_container_subset(idx_hidden_, VisibleDimensionIds{}); + } + + // TODO make these private + HiddenIndex idx_hidden_; +}; + +template +struct TensorCoordinateStep +{ + // TODO make these private + using VisibleIndex = MultiIndex; + + public: + __host__ __device__ constexpr TensorCoordinateStep() = default; + + __host__ __device__ constexpr TensorCoordinateStep(const VisibleIndex& idx_diff_visible, + const MultiIndex& do_transforms) + : idx_diff_visible_{idx_diff_visible}, do_transforms_{do_transforms} + { + } + + __host__ __device__ constexpr const auto& GetIndexDiff() const { return GetVisibleIndexDiff(); } + + // TODO make these private + __host__ __device__ constexpr const auto& GetVisibleIndexDiff() const + { + return idx_diff_visible_; + } + + VisibleIndex idx_diff_visible_; + MultiIndex do_transforms_; + + // HACK: control UpdateLowerIndex() + static constexpr UpdateLowerIndexHack update_lower_index_hack_; +}; + +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor, and to put it outside the scope where it is used +// (transform_tensor_descriptor) because template cannot be defined inside a function +// template +template +struct lambda_get_up_dim_num +{ + template + __host__ __device__ constexpr auto operator()(I) const + { + using Tran = remove_reference_t; + return Number{}; + } +}; + +template +__host__ __device__ constexpr auto +transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc, + const NewTransforms& new_transforms, + NewLowerDimensionOldVisibleIdss, + NewUpperDimensionNewVisibleIdss) +{ + // sanity check + { + static_assert(NewTransforms::Size() == NewLowerDimensionOldVisibleIdss::Size() && + NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(), + "wrong! inconsitent number of transform"); + + constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); }, + NewLowerDimensionOldVisibleIdss{}); + + constexpr auto all_new_top_ids = unpack([](auto... 
xs) { return merge_sequences(xs...); }, + NewUpperDimensionNewVisibleIdss{}); + + static_assert(is_valid_sequence_map::value && + is_valid_sequence_map::value, + "wrong!"); + } + + // lower dimension's hidden idss + // convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of + // sequences) + constexpr auto low_dim_hidden_idss = transform_tuples( + // convert lower dimension visible ids (a sequence) to hidden ids (a sequence) + [](auto low_dim_visible_ids) constexpr { + return transform_sequences( + // convert lower dimension visible id to hidden id + [](auto low_dim_visible_id) constexpr { + return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id]; + }, + low_dim_visible_ids); + }, + NewLowerDimensionOldVisibleIdss{}); + + constexpr index_t num_new_transform = NewTransforms::Size(); + + // upper dimension's hidden idss + constexpr index_t old_hidden_dim_number = OldTensorDescriptor::GetNumOfHiddenDimension(); + + constexpr auto up_dim_numbers = + generate_sequence(lambda_get_up_dim_num{}, Number{}); + + constexpr auto up_dim_numbers_scan = merge_sequences( + Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus{}, Number<0>{})); + + constexpr auto up_dim_hidden_idss = generate_tuple( + [ old_hidden_dim_number, up_dim_numbers_scan ](auto i) constexpr { + return + typename arithmetic_sequence_gen::type{}; + }, + Number{}); + + // new visible dimension's hidden ids + constexpr auto unordered_new_visible_dim_hidden_ids = unpack( + [](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss); + + constexpr auto new_visible_dim_unordered2ordered = unpack( + [](auto... xs) constexpr { return merge_sequences(xs...); }, + NewUpperDimensionNewVisibleIdss{}); + + constexpr auto new_visible_dim_hidden_ids = + unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered); + + // put everything together + const auto all_transforms = container_concat(old_tensor_desc.GetTransforms(), new_transforms); + + constexpr auto all_low_dim_hidden_idss = + container_concat(OldTensorDescriptor::GetLowerDimensionIdss(), low_dim_hidden_idss); + + constexpr auto all_up_dim_hidden_idss = + container_concat(OldTensorDescriptor::GetUpperDimensionIdss(), up_dim_hidden_idss); + + const auto element_space_size = old_tensor_desc.GetElementSpaceSize(); + + return TensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{all_transforms, + element_space_size}; +} + +template +__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc, + const VisibleIndex& idx_visible) +{ + static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), + "wrong! 
# of dimension inconsistent"); + + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); + constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds(); + + MultiIndex idx_hidden; + + // initialize visible index + set_container_subset(idx_hidden, visible_dim_ids, idx_visible); + + // calculate hidden index + static_for{}([&tensor_desc, &idx_hidden](auto itran_p1) { + auto itran = itran_p1 - Number<1>{}; + const auto& tran = tensor_desc.GetTransforms().At(itran); + constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); + constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); + + const auto idx_up = get_container_subset(idx_hidden, dims_up); + + MultiIndex idx_low; + + tran.CalculateLowerIndex(idx_low, idx_up); + + set_container_subset(idx_hidden, dims_low, idx_low); + }); + + return TensorCoordinate{idx_hidden}; +} + +// UpdateLowerIndexHack: Sequence<...> +// HACK: control UpdateLowerIndex +template +__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&, + const VisibleIndex& idx_diff_visible, + UpdateLowerIndexHack) +{ + static_assert(TensorDesc::GetNumOfDimension() == VisibleIndex::Size(), + "wrong! # of dimension inconsistent"); + + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); + constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension(); + constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds(); + + static_assert(UpdateLowerIndexHack::Size() == ntransform, "wrong!"); + + // use index_t for boolean type + auto do_transforms = make_zero_multi_index(); + auto is_non_zero_diff = make_zero_multi_index(); + + // decide do_transform by checkout non-zero index diff components + MultiIndex non_zero_diff_pick_visible; + + static_for<0, ndim_visible, 1>{}( + [&](auto i) { non_zero_diff_pick_visible(i) = (idx_diff_visible[i] != 0); }); + + set_container_subset(is_non_zero_diff, visible_dim_ids, non_zero_diff_pick_visible); + + static_for{}([&](auto itran) { + constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); + constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); + + const auto non_zero_diff_pick_up = get_container_subset(is_non_zero_diff, dims_up); + + MultiIndex non_zero_diff_pick_low; + + // if any of upper index diff components is non-zero, then + // 1) Need to do this transform + // 2) all components of lower index diff will assume to be non-zero and need to be + // computed + const bool idx_diff_up_has_non_zero = container_reduce( + non_zero_diff_pick_up, [](auto a, auto b) constexpr { return a or b; }, false); + + do_transforms(itran) = idx_diff_up_has_non_zero; + + static_for<0, dims_low.Size(), 1>{}( + [&](auto i) { non_zero_diff_pick_low(i) = idx_diff_up_has_non_zero; }); + + set_container_subset(is_non_zero_diff, dims_low, non_zero_diff_pick_low); + }); + + return TensorCoordinateStep{idx_diff_visible, + do_transforms}; +} + +template +__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc&, + const VisibleIndex& idx_diff_visible) +{ + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + + return make_tensor_coordinate_step( + TensorDesc{}, idx_diff_visible, typename uniform_sequence_gen::type{}); +} + +template +__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc& tensor_desc, + TensorCoord& coord, + const 
TensorCoordStep& coord_step) +{ + constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension(); + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + + // this is what needs to be calculated + auto idx_diff_hidden = make_zero_multi_index(); + + // initialize visible index diff + set_container_subset( + idx_diff_hidden, TensorDesc::GetVisibleDimensionIds(), coord_step.GetVisibleIndexDiff()); + + // this is what needs to be updated + auto& idx_hidden = coord.GetHiddenIndex(); + + // update visible index + auto idx_hidden_pick_visible = + get_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds()); + + idx_hidden_pick_visible += coord_step.GetIndexDiff(); + + set_container_subset(idx_hidden, TensorDesc::GetVisibleDimensionIds(), idx_hidden_pick_visible); + + // update rest of hidden index + static_for{}([&](auto itran) { + if(coord_step.do_transforms_[itran]) + { + const auto& tran = tensor_desc.GetTransforms().At(itran); + constexpr auto dims_low = TensorDesc::GetLowerDimensionIdss().At(itran); + constexpr auto dims_up = TensorDesc::GetUpperDimensionIdss().At(itran); + + const auto idx_up_new = get_container_subset(idx_hidden, dims_up); + auto idx_low = get_container_subset(idx_hidden, dims_low); + const auto idx_diff_up = get_container_subset(idx_diff_hidden, dims_up); + + MultiIndex idx_diff_low; + + // HACK: control UpdateLowerIndex for Merge using hack + constexpr index_t Hack = decltype(coord_step.update_lower_index_hack_)::At(itran); + + tran.UpdateLowerIndex(idx_diff_low, idx_diff_up, idx_low, idx_up_new, Number{}); + + set_container_subset(idx_diff_hidden, dims_low, idx_diff_low); + set_container_subset(idx_hidden, dims_low, idx_low); + } + }); +} + +template +__host__ __device__ constexpr bool +coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc& tensor_desc, + const TensorCoord& coord) +{ + bool valid = true; + + constexpr index_t ntransform = TensorDesc::GetNumOfTransform(); + + const auto& idx_hidden = coord.GetHiddenIndex(); + + static_for{}([&tensor_desc, &idx_hidden, &valid](auto itran) { + const auto tran = tensor_desc.GetTransforms().At(itran); + + // check validity, only if current transformation does not always has a valid mapping + if constexpr(!decltype(tran)::IsValidUpperIndexAlwaysMappedToValidLowerIndex()) + { + const auto idx_up = + get_container_subset(idx_hidden, TensorDesc::GetUpperDimensionIdss().At(itran)); + + // Comment: using valid = valid && .. 
will result in weird control flow in ISA + valid &= tran.IsValidUpperIndexMappedToValidLowerIndex(idx_up); + } + }); + + return valid; +} + +template +__host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc, + const TensorCoord& coord) +{ + // check visible index + const auto& idx_visible = coord.GetVisibleIndex(); + + bool is_visible_index_valid = true; + + static_for<0, TensorDesc::GetNumOfDimension(), 1>{}( + [&is_visible_index_valid, &idx_visible, &tensor_desc](auto i) { + is_visible_index_valid = + is_visible_index_valid && + (idx_visible[i] >= 0 && idx_visible[i] < tensor_desc.GetLength(i)); + }); + + // check other hidden index + return is_visible_index_valid && + coordinate_has_valid_offset_assuming_visible_index_is_valid(tensor_desc, coord); +} + +template +using TensorCoordinate_t = decltype(make_tensor_coordinate( + TensorDesc{}, MultiIndex::GetNumOfDimension()>{})); + +template +using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step( + TensorDesc{}, MultiIndex::GetNumOfDimension()>{})); + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_description/tensor_descriptor_helper.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_description/tensor_descriptor_helper.hpp new file mode 100755 index 000000000000..f3ac041bf9b8 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_description/tensor_descriptor_helper.hpp @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" + +namespace ck { + +/* + * These functions create tensor descriptor at runtime. If they are not constexpr, you will + * likely see usage of scratch memory during construction of these tensor descriptors. So + * it's better to call these functions on host and then pass the constructed tensor descritpors + * to GPU. If the tensor descritpors being constructed are constexpr, then you can call these + * functions on GPU without worrying about scratch memory usage. + */ + +#if CK_WORKAROUND_SWDEV_275126 +template +__host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengths& lengths, + const Strides& strides, + Number i, + AccOld acc_old) +{ + auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i]; + + if constexpr(i.value < Lengths::Size() - 1) + { + return calculate_element_space_size_impl(lengths, strides, i + Number<1>{}, acc_new); + } + else + { + return acc_new; + } +} +#endif + +// Lengths..., Strides... 
could be: +// 1) index_t, which is known at run-time, or +// 2) Number<>, which is known at compile-time +// element_space_size could be: +// 1) long_index_t, or +// 2) LongNumber<> +template ::type = false> +__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple& lengths, + const Tuple& strides) +{ + constexpr index_t N = sizeof...(Lengths); + + const auto transforms = make_tuple(make_embed_transform(lengths, strides)); + + constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{}); + + constexpr auto up_dim_hidden_idss = + make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{}); + + constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; + +#if !CK_WORKAROUND_SWDEV_275126 + // rocm-4.1 compiler would crash for recursive labmda + // recursive function for reduction + auto f = [&](auto fs, auto i, auto acc_old) { + auto acc_new = acc_old + (lengths[i] - Number<1>{}) * strides[i]; + + if constexpr(i.value < N - 1) + { + return fs(fs, i + Number<1>{}, acc_new); + } + else + { + return acc_new; + } + }; + + const auto element_space_size = f(f, Number<0>{}, LongNumber<1>{}); +#else + const auto element_space_size = + calculate_element_space_size_impl(lengths, strides, Number<0>{}, LongNumber<1>{}); +#endif + + return TensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{transforms, + element_space_size}; +} + +// Lengths... could be: +// 1) index_t, which is known at run-time, or +// 2) Number<>, which is known at compile-time +// element_space_size could be: +// 1) long_index_t, or +// 2) LongNumber<> +template +__host__ __device__ constexpr auto +make_naive_tensor_descriptor_packed(const Tuple& lengths) +{ + constexpr index_t N = sizeof...(Lengths); + + const auto transforms = make_tuple(make_unmerge_transform(lengths)); + + constexpr auto low_dim_hidden_idss = make_tuple(Sequence<0>{}); + + constexpr auto up_dim_hidden_idss = + make_tuple(typename arithmetic_sequence_gen<1, N + 1, 1>::type{}); + + constexpr auto visible_dim_hidden_ids = typename arithmetic_sequence_gen<1, N + 1, 1>::type{}; + + const auto element_space_size = container_reduce(lengths, math::multiplies{}, LongNumber<1>{}); + + return TensorDescriptor, + remove_cv_t, + remove_cv_t, + remove_cv_t, + remove_cv_t>{transforms, + element_space_size}; +} + +// Lengths... 
could be: +// 1) index_t, which is known at run-time, or +// 2) Number<>, which is known at compile-time +// align could be: +// 1) index_t, or +// 2) Number<> +template +__host__ __device__ constexpr auto +make_naive_tensor_descriptor_aligned(const Tuple& lengths, Align align) +{ + constexpr auto I1 = Number<1>{}; + + constexpr index_t N = sizeof...(Lengths); + + const auto stride_n_minus_2 = math::integer_least_multiple(lengths[Number{}], align); + + auto strides = generate_tuple( + [&](auto i) { + if constexpr(i.value == N - 1) + { + return I1; + } + else if constexpr(i.value == N - 2) + { + return Number{}; + } + else + { + return container_reduce(lengths, + math::multiplies{}, + Number{}, + i + I1, + Number{}, + I1); + } + }, + Number{}); + + return make_naive_tensor_descriptor(lengths, strides); +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_description/tensor_space_filling_curve.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_description/tensor_space_filling_curve.hpp new file mode 100755 index 000000000000..9a326092d2e0 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_description/tensor_space_filling_curve.hpp @@ -0,0 +1,162 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/math.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/utility/statically_indexed_array_multi_index.hpp" +#include "ck/utility/tuple_helper.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template // # of scalars per access in each dimension +struct SpaceFillingCurve +{ + static constexpr index_t nDim = TensorLengths::Size(); + + using Index = MultiIndex; + + static constexpr index_t ScalarPerVector = + reduce_on_sequence(ScalarsPerAccess{}, math::multiplies{}, Number<1>{}); + + static constexpr auto access_lengths = TensorLengths{} / ScalarsPerAccess{}; + static constexpr auto dim_access_order = DimAccessOrder{}; + static constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(ordered_access_lengths)), + make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr index_t GetNumOfAccess() + { + static_assert(TensorLengths::Size() == ScalarsPerAccess::Size()); + static_assert(TensorLengths{} % ScalarsPerAccess{} == + typename uniform_sequence_gen::type{}); + + return reduce_on_sequence(TensorLengths{}, math::multiplies{}, Number<1>{}) / + ScalarPerVector; + } + + template + static __device__ __host__ constexpr auto GetStepBetween(Number, + Number) + { + static_assert(AccessIdx1dBegin >= 0, "1D index should be non-negative"); + static_assert(AccessIdx1dBegin < GetNumOfAccess(), "1D index should be larger than 0"); + static_assert(AccessIdx1dEnd >= 0, "1D index should be non-negative"); + static_assert(AccessIdx1dEnd < GetNumOfAccess(), "1D index should be larger than 0"); + + constexpr auto idx_begin = GetIndex(Number{}); + constexpr auto idx_end = GetIndex(Number{}); + return idx_end - idx_begin; + } + + template + static __device__ __host__ constexpr auto GetForwardStep(Number) + { + static_assert(AccessIdx1d < GetNumOfAccess(), "1D index should be larger than 0"); 
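        // The forward step is the multi-dimensional index delta between access AccessIdx1d
        // and access AccessIdx1d + 1, already scaled by ScalarsPerAccess. As an illustrative
        // example (assuming TensorLengths = Sequence<2, 8>, ScalarsPerAccess = Sequence<1, 4>,
        // DimAccessOrder = Sequence<0, 1> and SnakeCurved == true), the curve visits
        // (0,0), (0,4), (1,4), (1,0), so the forward steps are (0,4), (1,0), (0,-4).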
+ return GetStepBetween(Number{}, Number{}); + } + + template + static __device__ __host__ constexpr auto GetBackwardStep(Number) + { + static_assert(AccessIdx1d > 0, "1D index should be larger than 0"); + + return GetStepBetween(Number{}, Number{}); + } + + template + static __device__ __host__ constexpr Index GetIndex(Number) + { +#if 0 + /* + * \todo: TensorAdaptor::CalculateBottomIndex does NOT return constexpr as expected. + */ + constexpr auto ordered_access_idx = to_index_adaptor.CalculateBottomIndex(make_multi_index(Number{})); +#else + + constexpr auto access_strides = container_reverse_exclusive_scan( + ordered_access_lengths, math::multiplies{}, Number<1>{}); + + constexpr auto idx_1d = Number{}; + // Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the + // idim-th element of multidimensional index. + // All constexpr variables have to be captured by VALUE. + constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr + { + constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr + { + auto res = idx_1d.value; + auto id = 0; + + static_for<0, jdim.value + 1, 1>{}([&](auto kdim) { + id = res / access_strides[kdim].value; + res -= id * access_strides[kdim].value; + }); + + return id; + }; + + constexpr auto id = compute_index_impl(idim); + return Number{}; + }; + + constexpr auto ordered_access_idx = generate_tuple(compute_index, Number{}); +#endif + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto idim) { + index_t tmp = ordered_access_idx[I0]; + + static_for<1, idim, 1>{}( + [&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; }); + + forward_sweep_(idim) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate multi-dim tensor index + auto idx_md = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto idim) { + ordered_idx(idim) = + !SnakeCurved || forward_sweep[idim] + ? ordered_access_idx[idim] + : ordered_access_lengths[idim] - 1 - ordered_access_idx[idim]; + }); + + return container_reorder_given_old2new(ordered_idx, dim_access_order) * + ScalarsPerAccess{}; + }(); + return idx_md; + } + + // FIXME: rename this function + template + static __device__ __host__ constexpr auto GetIndexTupleOfNumber(Number) + { + constexpr auto idx = GetIndex(Number{}); + + return generate_tuple([&](auto i) { return Number{}; }, Number{}); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp new file mode 100755 index 000000000000..f23404a1d7ba --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp @@ -0,0 +1,412 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp" + +namespace ck { + +// C[BM0, BM1, BN0, BN1] += transpose(A[K, BM0, BM1]) * B[K, BN0, BN1] +// A and B are visible to the whole block, C is distributed among each thread +// Assume: +// 1. A: +// 1. ABlockDesc_BK0_BM_BK1 is known at compile-time +// 2. 
ABlockBuffer is DynamicBuffer +// 2. B: +// 1. BBlockDesc_BK0_BN_BK1 is known at compile-time +// 2. BBlockBuffer is DynamicBuffer +// 3. C: +// 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time +// 2. CThreadBuffer is StaticBuffer +// Also assume: +// BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2 +// BM0 = BN0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization) +template + typename BM10BN10ThreadClusterBN10Xs, // Sequence + index_t AThreadCopyScalarPerVector_BM11, + index_t BThreadCopyScalarPerVector_BN11, + typename enable_if::type = false> +struct BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 +{ + using AIndex = MultiIndex<3>; + using BIndex = MultiIndex<3>; + using CIndex = MultiIndex<4>; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t BK0 = ABlockDesc_BK0_BM_BK1{}.GetLength(I0); + static constexpr index_t BK1 = ABlockDesc_BK0_BM_BK1{}.GetLength(I2); + static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1); + static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1); + + static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0]; + static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0]; + + static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1]; + static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1]; + + static constexpr index_t BM11 = BM1PerThreadBM11; + static constexpr index_t BN11 = BN1PerThreadBN11; + + static constexpr index_t BM1 = BM100 * BM101 * BM11; + static constexpr index_t BN1 = BN100 * BN101 * BN11; + + static constexpr index_t BM0 = BM / BM1; + static constexpr index_t BN0 = BN / BN1; + + __host__ __device__ static constexpr auto + MakeABlockDescriptor_BK0_BM0_BM1_BK1(const ABlockDesc_BK0_BM_BK1& a_block_desc_bk0_bm_bk1) + { + const auto a_block_bk0_bm0_bm1_bk1 = transform_tensor_descriptor( + a_block_desc_bk0_bm_bk1, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return a_block_bk0_bm0_bm1_bk1; + } + + __host__ __device__ static constexpr auto + MakeBBlockDescriptor_BK0_BN0_BN1_BK1(const BBlockDesc_BK0_BN_BK1& b_block_desc_bk0_bn_bk1) + { + const auto b_block_desc_bk0_bn0_bn1_bk1 = transform_tensor_descriptor( + b_block_desc_bk0_bn_bk1, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return b_block_desc_bk0_bn0_bn1_bk1; + } + + __host__ __device__ static constexpr auto + MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM_BN() + { + // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + // lower: [BM, BN] + constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{})); + + 
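        // The two unmerge transforms split the block-level extents as
        // BM = BM0 * BM100 * BM101 * BM11 and BN = BN0 * BN100 * BN101 * BN11, so a thread's
        // sub-tile coordinate in the upper [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11]
        // space can be mapped back onto the block tile's lower [BM, BN] space.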
return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n; + } + + __host__ __device__ static constexpr auto + MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1() + { + // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + // lower: [BM0, BM1, BN0, BN1] + constexpr auto c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1 = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{})); + + return c_block_adaptor_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1; + } + + __host__ __device__ static constexpr auto GetCThreadTensorLengths_BM0_BM1_BN0_BN1() + { + return Sequence{}; + } + + static constexpr auto a_block_desc_bk0_bm0_bm1_bk1_ = + MakeABlockDescriptor_BK0_BM0_BM1_BK1(ABlockDesc_BK0_BM_BK1{}); + + static constexpr auto b_block_desc_bk0_bn0_bn1_bk1_ = + MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{}); + + public: + __device__ BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2() + : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id())}, + a_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)}, + b_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)} + { + static_assert(ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() && + BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(BlockSize == BM101 * BM100 * BN101 * BN100, + "wrong! blocksize and cluster size not consistent"); + + static_assert(BM % BM1 == 0 && BN % BN1 == 0, "wrong!"); + + static_assert(ABlockDesc_BK0_BM_BK1{}.GetLength(I0) == + BBlockDesc_BK0_BN_BK1{}.GetLength(I0), + "wrong! 
K dimension not consistent"); + + // TODO remove this restriction + static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 && + BM10BN10ThreadClusterBN10Xs::Size() == 2, + "wrong!"); + + // TODO: remove this restriction + static_assert(BM0 == 2, "wrong"); + static_assert(BM0 == 2 && BN0 == 2, "wrong"); + } + + __device__ static CIndex CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(index_t thread_id) + { + // lower: [BM0, BM1, BN0, BN1] + // upper: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + constexpr auto adaptor0 = + MakeCBlockAdaptor_BM0_BM100_BM101_BM11_BN0_BN100_BN101_BN11_To_BM0_BM1_BN0_BN1(); + + // lower: [BM0, BM100, BM101, BM11, BN0, BN100, BN101, BN11] + // upper: [Tid, BM0, BM11, BN0, BN11] + constexpr auto adaptor1 = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(BM100, BN100, BM101, BN101)), + make_pass_through_transform(BM0), + make_pass_through_transform(BM11), + make_pass_through_transform(BN0), + make_pass_through_transform(BN11)), + make_tuple( + Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); + + return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0)); + } + + template + __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&, + const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + static_assert(CThreadDesc_BM0_BM11_BN0_BN11::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + // TODO: remove this restriction + static_assert(BM0 == 2 && BN0 == 2 && + CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I0) == BM0 && + CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0, + "wrong"); + + auto a_thread_buf = make_static_buffer( + a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize()); + + constexpr auto threadwise_contraction = + ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1< + FloatA, + FloatB, + FloatC, + decltype(a_thread_desc_bk0_bm0_bm1_bk1_), + decltype(b_thread_desc_bk0_bn0_bn1_bk1_), + CThreadDesc_BM0_BM11_BN0_BN11, + Sequence, + Sequence<1, BM1PerThreadBM11>, + Sequence<1, BN1PerThreadBN11>>{}; + + // read A_sub_0 + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I0, I0, I0), + a_block_buf, + a_thread_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // read B_sub_0 + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I0, I0, I0), + b_block_buf, + b_thread_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + // read B_sub_1 + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I1, I0, I0), + b_block_buf, + b_thread_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I1, I0, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I1, I0, I0), + a_block_buf, + a_thread_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I1, I0, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + 
b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + + // loop over rest of bk0 + static_for{}([&](auto bk0) { + // read A_sub_0 + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(bk0, I0, I0, I0), + a_block_buf, + a_thread_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // read B_sub_0 + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(bk0, I0, I0, I0), + b_block_buf, + b_thread_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + + // read B_sub_1 + b_thread_copy_.Run(b_block_desc_bk0_bn0_bn1_bk1_, + make_tuple(bk0, I1, I0, I0), + b_block_buf, + b_thread_desc_bk0_bn0_bn1_bk1_, + make_tuple(I0, I1, I0, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_block_desc_bk0_bm0_bm1_bk1_, + make_tuple(bk0, I1, I0, I0), + a_block_buf, + a_thread_desc_bk0_bm0_bm1_bk1_, + make_tuple(I0, I1, I0, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + }); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_contraction.Run(a_thread_buf, + make_tuple(I0, I1, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + } + + private: + // A[BK0, BM0, BM1, BK1] + static constexpr auto a_thread_desc_bk0_bm0_bm1_bk1_ = + make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, Number{}, Number{})); + + // B[BK0, BN0, BN1, BK1] + static constexpr auto b_thread_desc_bk0_bn0_bn1_bk1_ = + make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, Number{}, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4r1< + FloatA, + FloatA, + decltype(a_block_desc_bk0_bm0_bm1_bk1_), + decltype(a_thread_desc_bk0_bm0_bm1_bk1_), + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + Sequence<1, 1, BM1PerThreadBM11, BK1>, // SrcVectorTensorLengths + Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4r1< + FloatB, + FloatB, + decltype(b_block_desc_bk0_bn0_bn1_bk1_), + decltype(b_thread_desc_bk0_bn0_bn1_bk1_), + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + Sequence<1, 1, BN1PerThreadBN11, BK1>, // SrcVectorTensorLengths + Sequence<0, 1, 2, 3>>; // SrcVectorTensorContiguousDimOrder + + CIndex c_thread_origin_data_idx_; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git 
a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp new file mode 100755 index 000000000000..b0143366c1de --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp @@ -0,0 +1,397 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP +#define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP + +#include "common_header.hpp" +#include "tensor_adaptor.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "threadwise_contraction_dlops.hpp" + +namespace ck { + +// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1] +// A and B are visable to the whole block, C is distributed among each thread +// Assume: +// 1. A: +// 1. AKMBlockDesc is known at compile-time +// 2. ABlockBuffer is DynamicBuffer +// 2. B: +// 1. BKNBlockDesc is known at compile-time +// 2. BBlockBuffer is DynamicBuffer +// 3. C: +// 1. CM0M1N0N1ThreadDesc is known at compile-time +// 2. CThreadBuffer is StaticBuffer +// Also assume: +// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization) +template < + index_t BlockSize, + typename FloatA, + typename FloatB, + typename FloatC, + typename AKMBlockDesc, + typename BKNBlockDesc, + index_t M1PerThreadM11, + index_t N1PerThreadN11, + index_t KPerThread, + index_t M1N1ThreadClusterM100, + index_t M1N1ThreadClusterN100, + index_t M1N1ThreadClusterM101, + index_t M1N1ThreadClusterN101, + index_t AThreadCopyScalarPerVector_M11, + index_t BThreadCopyScalarPerVector_N11, + typename enable_if::type = false> +struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 +{ + using AIndex = MultiIndex<3>; + using BIndex = MultiIndex<3>; + using CIndex = MultiIndex<4>; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t K = AKMBlockDesc{}.GetLength(I0); + static constexpr index_t M = AKMBlockDesc{}.GetLength(I1); + static constexpr index_t N = BKNBlockDesc{}.GetLength(I1); + + static constexpr index_t M100 = M1N1ThreadClusterM100; + static constexpr index_t N100 = M1N1ThreadClusterN100; + + static constexpr index_t M101 = M1N1ThreadClusterM101; + static constexpr index_t N101 = M1N1ThreadClusterN101; + + static constexpr index_t M11 = M1PerThreadM11; + static constexpr index_t N11 = N1PerThreadN11; + + static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11; + static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11; + + static constexpr index_t M0 = M / M1; + static constexpr index_t N0 = N / N1; + + __host__ __device__ static constexpr auto + MakeAKM0M1BlockDescriptor(const AKMBlockDesc& /* a_k_m_block_desc */) + { + const auto a_k_m0_m1_block_desc = transform_tensor_descriptor( + AKMBlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return a_k_m0_m1_block_desc; + } + + __host__ __device__ static constexpr auto + MakeBKN0N1BlockDescriptor(const BKNBlockDesc& /* b_k_n_block_desc */) + { + const auto b_k_n0_n1_block_desc = transform_tensor_descriptor( + BKNBlockDesc{}, + 
make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return b_k_n0_n1_block_desc; + } + + __host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor() + { + // upper: [M0, M100, M101, M11, N0, N100, N101, N11] + // lower: [M, N] + constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{})); + + return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor; + } + + __host__ __device__ static constexpr auto + MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor() + { + // upper: [M0, M100, M101, M11, N0, N100, N101, N11] + // lower: [M0, M1, N0, N1] + constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{})); + + return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor; + } + + __host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths() + { + return Sequence{}; + } + + static constexpr auto a_k_m0_m1_block_desc_ = MakeAKM0M1BlockDescriptor(AKMBlockDesc{}); + static constexpr auto b_k_n0_n1_block_desc_ = MakeBKN0N1BlockDescriptor(BKNBlockDesc{}); + + public: + __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2() + : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock( + get_thread_local_1d_id())}, + a_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])}, + b_thread_copy_{ + make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])} + { + static_assert(AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(BlockSize == M101 * M100 * N101 * N100, + "wrong! blocksize and cluster size not consistent"); + + static_assert(M % M1 == 0 && N % N1 == 0, "wrong!"); + + static_assert(AKMBlockDesc{}.GetLength(I0) == BKNBlockDesc{}.GetLength(I0), + "wrong! 
K dimension not consistent"); + + // TODO: remove this restriction + static_assert(M0 == 2 && N0 == 2, "wrong"); + } + + __device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id) + { + // lower: [M0, M1, N0, N1] + // upper: [M0, M100, M101, M11, N0, N100, N101, N11] + constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor(); + + // lower: [M0, M100, M101, M11, N0, N100, N101, N11] + // upper: [Tid, M0, M11, N0, N11] + constexpr auto adaptor1 = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)), + make_pass_through_transform(M0), + make_pass_through_transform(M11), + make_pass_through_transform(N0), + make_pass_through_transform(N11)), + make_tuple( + Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1); + + return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0)); + } + + __host__ __device__ static constexpr index_t GetABlockAlignment() { return M1PerThreadM11; } + + __host__ __device__ static constexpr auto GetBBlockAlignment() { return N1PerThreadN11; } + + template + __device__ void Run(const CM0M1N0N1ThreadDesc& /* c_m0_m1_n0_n1_thread_desc */, + const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + // TODO: remove this restriction + static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 && + CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0, + "wrong"); + + auto a_thread_buf = make_static_buffer( + a_k_m0_m1_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_k_n0_n1_thread_desc_.GetElementSpaceSize()); + + constexpr auto threadwise_gemm = + ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1, + Sequence<1, M1PerThreadM11>, + Sequence<1, N1PerThreadN11>>{}; + + // read A_sub_0 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(I0, I0, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I0, I0), + a_thread_buf); + + // read B_sub_0 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(I0, I0, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I0, I0), + b_thread_buf); + + // read B_sub_1 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(I0, I1, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I1, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(I0, I1, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I1, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + + // loop over rest of k + static_for{}([&](auto k) { + // read A_sub_0 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(k, I0, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I0, I0), + a_thread_buf); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + 
b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // read B_sub_0 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(k, I0, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I0, I0), + b_thread_buf); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + + // read B_sub_1 + b_thread_copy_.Run(b_k_n0_n1_block_desc_, + make_tuple(k, I1, I0), + b_block_buf, + b_k_n0_n1_thread_desc_, + make_tuple(I0, I1, I0), + b_thread_buf); + + // read A_sub_1 + a_thread_copy_.Run(a_k_m0_m1_block_desc_, + make_tuple(k, I1, I0), + a_block_buf, + a_k_m0_m1_thread_desc_, + make_tuple(I0, I1, I0), + a_thread_buf); + + // C_sub_00 += transpose(A_sub_0) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I0, I0, I0, I0)); + + // C_sub_01 += transpose(A_sub_0) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I0, I0, I1, I0)); + }); + + // C_sub_10 += transpose(A_sub_1) * B_sub_0 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I0, I0), + c_thread_buf, + make_tuple(I1, I0, I0, I0)); + + // C_sub_11 += transpose(A_sub_1) * B_sub_1 + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I1, I0), + b_thread_buf, + make_tuple(I0, I1, I0), + c_thread_buf, + make_tuple(I1, I0, I1, I0)); + } + + private: + // A[K, M0, M1] + static constexpr auto a_k_m0_m1_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + // B[K, N0, N1] + static constexpr auto b_k_n0_n1_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + AThreadCopyScalarPerVector_M11, + 1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + BThreadCopyScalarPerVector_N11, + 1>; + + CIndex c_thread_origin_data_idx_; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp new file mode 100755 index 000000000000..0d092da5168c --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
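Both DL-style pipelines above (the BK0/BM/BK1 variant and blockwise_gemm_dlops_v2r2) implement the 2x2 "ABBA" schedule described in their header comments: C is split into four sub-tiles and, in the steady-state loop, the reads of A_sub/B_sub for the next K-slice are interleaved with the FMAs of the current slice. Below is a minimal host-side sketch of that schedule using plain float arrays and made-up per-thread tile sizes; none of these names come from the patch, it only mirrors the read/compute ordering of Run().

#include <cstdio>

// Per-thread sub-tile sizes; purely illustrative values, not from the patch.
constexpr int K = 4, M1 = 2, N1 = 2;

// C_sub[M1 x N1] += transpose(a[M1]) * b[N1] for one K-slice.
static void fma(const float* a, const float* b, float* c)
{
    for(int m = 0; m < M1; ++m)
        for(int n = 0; n < N1; ++n)
            c[m * N1 + n] += a[m] * b[n];
}

int main()
{
    float A[K][2][M1], B[K][2][N1]; // two M (resp. N) sub-blocks per K-slice
    for(int k = 0; k < K; ++k)
    {
        for(int m = 0; m < M1; ++m) A[k][0][m] = A[k][1][m] = 1.0f;
        for(int n = 0; n < N1; ++n) B[k][0][n] = B[k][1][n] = 2.0f;
    }

    float C00[M1 * N1] = {}, C01[M1 * N1] = {}, C10[M1 * N1] = {}, C11[M1 * N1] = {};
    float a0[M1], a1[M1], b0[N1], b1[N1];
    auto load = [](const float* src, float* dst, int n) {
        for(int i = 0; i < n; ++i) dst[i] = src[i];
    };

    // prologue: read A_sub_0, B_sub_0, B_sub_1, A_sub_1 of slice 0, then the first two FMAs
    load(A[0][0], a0, M1); load(B[0][0], b0, N1); load(B[0][1], b1, N1); load(A[0][1], a1, M1);
    fma(a0, b0, C00); // C_sub_00 += A_sub_0^T * B_sub_0
    fma(a0, b1, C01); // C_sub_01 += A_sub_0^T * B_sub_1

    // steady state: loads of slice k overlap the remaining FMAs of slice k-1
    for(int k = 1; k < K; ++k)
    {
        load(A[k][0], a0, M1);  fma(a1, b0, C10); // C_sub_10 += A_sub_1^T * B_sub_0
        load(B[k][0], b0, N1);  fma(a1, b1, C11); // C_sub_11 += A_sub_1^T * B_sub_1
        load(B[k][1], b1, N1);
        load(A[k][1], a1, M1);
        fma(a0, b0, C00);
        fma(a0, b1, C01);
    }

    // epilogue: finish the last slice
    fma(a1, b0, C10);
    fma(a1, b1, C11);

    std::printf("C00[0] = %f (expected %f)\n", C00[0], 2.0f * K);
    return 0;
}

Each load is issued as soon as its destination buffer is free, which is what allows the real kernels to overlap LDS reads with the FMA chain of the previous sub-tile pair.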
+ +#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP +#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP + +#include "common_header.hpp" +#include "threadwise_gemm_dlops_v3.hpp" + +namespace ck { + +template +struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + using AIndex = MultiIndex<3>; + using BIndex = MultiIndex<3>; + using CIndex = MultiIndex<4>; + + static constexpr auto E1 = ABlockDesc_E1_K1_E2{}.GetLength(I0); + static constexpr auto KPerBlock = ABlockDesc_E1_K1_E2{}.GetLength(I1); + static constexpr auto E2 = ABlockDesc_E1_K1_E2{}.GetLength(I2); + + static constexpr auto HoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I2); + static constexpr auto WoPerBlock = BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I3); + + static constexpr auto KPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I0); + static constexpr auto HoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I2); + static constexpr auto WoPerThread = CThreadDesc_K_N_Ho_Wo{}.GetLength(I3); + + static constexpr auto a_thread_mtx_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + static constexpr auto b_thread_mtx_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number<1>{}, + Number{}, + Number{}, + Number{})); + + static constexpr auto c_thread_mtx_ = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number<1>{}, Number{}, Number{})); + + __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3() + : c_thread_origin_data_idx_{GetBeginOfCThreadDesc_K_N_Ho_Wo(get_thread_local_1d_id())}, + a_thread_copy_{make_tuple(0, c_thread_origin_data_idx_[I0] * KPerThread, 0)} + { + static_assert(ABlockDesc_E1_K1_E2::IsKnownAtCompileTime() && + BBlockDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() && + CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert( + ABlockDesc_E1_K1_E2{}.GetLength(I0) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I0) && + ABlockDesc_E1_K1_E2{}.GetLength(I2) == BBlockDesc_E1_N_Ho_Wo_E2{}.GetLength(I4), + "wrong! E dimension not consistent\n"); + + static_assert(E1 % EPerThreadLoop == 0, ""); + static_assert(KPerThread % KPerThreadLoop == 0, ""); + + static_assert(KPerBlock % KPerThread == 0 && HoPerBlock % HoPerThread == 0 && + WoPerBlock % WoPerThread == 0, + "wrong! Cannot evenly divide work among\n"); + + constexpr auto KThreadCluster = KPerBlock / KPerThread; + constexpr auto HThreadCluster = HoPerBlock / HoPerThread; + constexpr auto WThreadCluster = WoPerBlock / WoPerThread; + + static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster, + "wrong! 
wrong blocksize\n"); + } + + __device__ static constexpr auto GetCThreadDesc_K_N_Ho_WoLengths() + { + return Sequence{}; + } + + __device__ static CIndex GetBeginOfCThreadDesc_K_N_Ho_Wo(index_t thread_id) + { + constexpr auto K0 = KPerBlock / KPerThread; + constexpr auto N0 = I1; + constexpr auto H0 = HoPerBlock / HoPerThread; + constexpr auto W0 = WoPerBlock / WoPerThread; + + constexpr auto c_threadid_to_k_n_h_w_thread_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(K0, N0, H0, W0))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto c_k_n_h_w_thread_cluster_idx = + c_threadid_to_k_n_h_w_thread_cluster_adaptor.CalculateBottomIndex( + make_multi_index(thread_id)); + + return c_k_n_h_w_thread_cluster_idx; + } + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BThreadBuffer& b_thread_buf, + CThreadBuffer& c_thread_buf) const + { + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); + + constexpr auto a_block_mtx = ABlockDesc_E1_K1_E2{}; + + // thread A buffer for GEMM + StaticBuffer + a_thread_buf; + + constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3{}; + + static_for<0, E1, EPerThreadLoop>{}([&](auto e_begin) { + static_for<0, KPerThread, KPerThreadLoop>{}([&](auto k_begin) { + a_thread_copy_.Run(a_block_mtx, + make_tuple(e_begin, k_begin, I0), + a_block_buf, + a_thread_mtx_, + make_tuple(I0, I0, I0), + a_thread_buf); + + threadwise_gemm.Run(a_thread_buf, + make_tuple(I0, I0, I0), + b_thread_buf, + make_tuple(e_begin, I0, I0, I0, I0), + c_thread_buf, + make_tuple(k_begin, I0, I0, I0)); + }); + }); + } + + template + __device__ void MoveABlockSliceWindow(const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx) + { + a_thread_copy_.MoveSrcSliceWindow(ABlockDesc_E1_K1_E2{}, a_block_slice_move_step_idx); + } + + private: + using AThreadCopy = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2>, + 2, + E2, + E2>; + + CIndex c_thread_origin_data_idx_; + + AThreadCopy a_thread_copy_; +}; + +} // namespace ck +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp new file mode 100755 index 000000000000..f03427a7ead1 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp @@ -0,0 +1,348 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/dpp_gemm.hpp" + +namespace ck { + +/** + * Blockwise GEMM that uses DPP instruction modifier to limit the amount of data loaded for each + * thread by sharing the data between threads in a lanegroup. + * + * In every iteration, each wave calculates a C tile of size `MPerDpp` * `NPerDpp`, there are + * `MRepeat` iterations for `M` dimension and `NRepeat` for `N` one. + * In total, the algorithm runs using + * `MPerBlock / (MRepeat * MPerDpp) * NPerBlock / (NRepeat * NPerDpp)` waves. 
+ */ +template +struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = + BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr auto dpp_gemm = DppGemm{}; + + static constexpr index_t KPerThread = KPerBlock / dpp_gemm.K0PerDpp; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerDpp); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerDpp); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex_M0_M1_M2_K() + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto dpp_a_idx = dpp_gemm.CalculateAThreadOriginDataIndex_K_M(); + const auto dpp_a_idx_k = dpp_a_idx[I0]; + const auto dpp_a_idx_m = dpp_a_idx[I1]; + return make_tuple(0, waveId_m, dpp_a_idx_m, KPerThread * dpp_a_idx_k); + } + + __device__ static auto CalculateBThreadOriginDataIndex_N0_N1_N2_K() + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto dpp_b_idx = dpp_gemm.CalculateBThreadOriginDataIndex_K_N(); + const auto dpp_b_idx_k = dpp_b_idx[I0]; + const auto dpp_b_idx_n = dpp_b_idx[I1]; + return make_tuple(0, waveId_n, dpp_b_idx_n, KPerThread * dpp_b_idx_k); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = dpp_gemm.GetBeginOfThreadBlk(); + const auto blk_m_offset = blk_idx[I0]; + const auto blk_n_offset = blk_idx[I1]; + + constexpr auto mrepeat_mwave_MPerDpp_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerDpp))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_NPerDpp_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerDpp))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_MPerDpp_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_m_offset))[I0]; + const index_t c_thread_n = nrepeat_nwave_NPerDpp_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_n_offset))[I0]; + + return 
make_tuple(c_thread_m, c_thread_n); + } + + __host__ __device__ BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "Wrong! Block descriptors should be known at the time of compilation."); + +#if defined(__HIP_DEVICE_COMPILE__) + // Host wave size can be different than the device one and this assert could fail for host, + // but it does matter only for device. + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); +#endif + + static_assert(MPerBlock % (MPerDpp * MRepeat) == 0, + "Invalid parameters. MPerBlock must be divisible by MPerDpp * MRepeat."); + static_assert(NPerBlock % (NPerDpp * NRepeat) == 0, + "Invalid parameters. NPerBlock must be divisible by NPerDpp * NRepeat."); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths(); + constexpr auto M = c_m_n_tblk_lens[I0]; + constexpr auto N = c_m_n_tblk_lens[I1]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_m_n_tblk_lens = dpp_gemm.GetCMNThreadBlkLengths(); + constexpr auto M = c_m_n_tblk_lens[I0]; + constexpr auto N = c_m_n_tblk_lens[I1]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M, N)); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return c_block_desc_m0_n0_m1_n1_m2_n2; + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + return c_block_desc_g_m0_n0_m1_n1_m2_n2; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return c_grid_desc_m0_n0_m1_n1_m2_n2; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerDpp), MWaves, MPerDpp)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerDpp), NWaves, NPerDpp))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, 
Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return c_grid_desc_g_m0_n0_m1_n1_m2_n2; + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K(); + static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K(); + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + static_for<0, KPerThread, KPack>{}([&](auto k) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + using dpp_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + dpp_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + protected: + // A[M0, M1, M2, KPerThread] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPerThread] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // C[M, N, NumRegDpp] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, dpp_gemm.GetRegSizePerDpp())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex_M0_M1_M2_K()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex_N0_N1_N2_K()}; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp new file mode 100755 index 000000000000..c929956124ff --- 
/dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp @@ -0,0 +1,426 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template +struct BlockwiseGemmXdlops_mx_pipeline_base +{ + using ComputeTypeA = ADataType; + using ComputeTypeB = BDataType; + using AccType = float; // for now only support V_MFMA_SCALE_F32 + + static constexpr index_t APackedSize = packed_size_v; + static constexpr index_t BPackedSize = packed_size_v; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs. + static constexpr index_t WaveSize = 64; + + static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); + static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); + // static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = + BTileDesc{}.GetLength(Number < BTileDesc{}.GetNumOfDimension() == 4 ? 3 : 2 > {}); + + static constexpr auto xdlops_gemm = XdlopsGemm{}; + + static constexpr index_t AMmaKStride = KPack; + static constexpr index_t BMmaKStride = KPack; + + // store rows/cols into thread registers in chunks of 16 for FP8 + // e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47] + // or in chunks of 32 / APackedSize for FP6/FP4 + static constexpr index_t KThreadChunk = (APackedSize == 1) ? 
16 : 32 / APackedSize; + + static_assert(APackedSize == BPackedSize, "APackedSize must be equal to BPackedSize for now"); + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + static constexpr index_t KRepeat = KPerThread / KPack; + static constexpr index_t KPerInnerLoop = KPack; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + // Hardcode to 2, for better 8-bit access pattern + + static constexpr index_t MXdlPack = 2; + static constexpr index_t NXdlPack = 2; + static constexpr index_t KXdlPack = 2; + + using HotLoopInstList = ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< // + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + A_K1, + B_K1, + A_K1, + B_K1, + MRepeat, + NRepeat, + MPerXDL, + NPerXDL, + xdlops_gemm.KPerXdlops, + (packed_size_v > 1 || packed_size_v > 1)>; + + static_assert(KPerThread % KPack == 0, + "Wrong KPack setting; try increasing KPerThread or decreasing KPack"); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, 0, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, 0, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_unmerge_transform(make_tuple(MRepeat / MXdlPack, MWaves, MXdlPack, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2, 3>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_unmerge_transform(make_tuple(NRepeat / NXdlPack, NWaves, NXdlPack, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2, 3>{})); + + // We pack 2 mfma in M/N direction, so we need to divide by 2 + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0 / MXdlPack, waveId_m, m0 % MXdlPack, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0 / NXdlPack, waveId_n, n0 % NXdlPack, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + using Tuple5 = 
decltype(CalculateAThreadOriginDataIndex()); + + /** + * @brief Constructor for BlockwiseGemmXdlops_mx_pipeline_base. + * + * This constructor initializes the thread copy objects for matrices A and B. + * It also performs several compile-time checks to ensure the correctness of the + * matrix tile descriptors. + * + * @param a_origin The origin data index for matrix A. + * @param b_origin The origin data index for matrix B. + * + * @note The constructor includes static assertions to ensure that: + * - The matrix tile descriptors for A and B are known at compile-time. + * - The number of threads in the thread block matches the product of MWaves, NWaves, and + * WaveSize. + * - The dimensions of the block are divisible by the product of the corresponding XDL and + * repeat dimensions. + */ + __host__ __device__ + BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin = CalculateAThreadOriginDataIndex(), + Tuple5 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, N, M0, M1, M2)); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl, packed mfma + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + I1, + Number{}, + Number{}, + M0, + M1, + M2, + N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return 
make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl_packed mfma + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3( + c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; } + + static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_m3_k; + static constexpr BMmaTileDesc 
b_block_desc_n0_n1_n2_n3_k; + + protected: + // M1, N1 as double buffer index + // Read buffer + Compute buffer + // A[M0, M1, M2, KPack] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, I1, Number{}, Number{}, Number{})); + + // B[N0, N1, N2, KPack] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, I1, Number{}, Number{}, Number{})); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4>, + 4, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4>, + 4, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp new file mode 100755 index 000000000000..bfb081330c51 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp" + +namespace ck { + +template +constexpr auto BlockGemmPipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + return BlockwiseGemmWmmaops_pipeline_v1{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return BlockwiseGemmWmmaops_pipeline_v3{}; + } + else + { + static_assert(false, "BlockGemmPipeline configuration is not available"); + } +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp new file mode 100755 index 000000000000..31c4729760d3 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
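BlockGemmPipeline_Selector() above is a compile-time dispatcher: the pipeline version is a non-type template parameter, `if constexpr` returns an instance of the chosen pipeline struct, and callers recover the concrete type with decltype. A minimal standalone sketch of the same pattern, with made-up type names rather than the CK pipelines:

#include <cstdio>

enum class PipelineVersion { v1, v3 };

struct PipelineV1 { static constexpr const char* name = "pipeline v1"; };
struct PipelineV3 { static constexpr const char* name = "pipeline v3"; };

template <PipelineVersion Ver>
constexpr auto SelectPipeline()
{
    // Only the taken branch is instantiated; the real selector additionally
    // rejects unsupported versions with a static_assert (omitted here).
    if constexpr(Ver == PipelineVersion::v1) { return PipelineV1{}; }
    else if constexpr(Ver == PipelineVersion::v3) { return PipelineV3{}; }
}

int main()
{
    // Callers name the selected pipeline type via decltype, as the CK kernels do.
    using Chosen = decltype(SelectPipeline<PipelineVersion::v3>());
    std::printf("selected: %s\n", Chosen::name);
    return 0;
}

Because discarded `if constexpr` branches are never instantiated, each kernel instantiation only compiles the one pipeline it actually selects.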
+ +#pragma once + +#include "ck/utility/common_header.hpp" + +namespace ck { + +template +struct BlockwiseGemmWmmaops_pipeline_hotloop_inst +{ + static constexpr index_t WaveSize = 32; + static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerWmma); + static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerWmma); + + static constexpr index_t A_LDS_Read_Width = ALDSReadWidth; + static constexpr index_t B_LDS_Read_Width = BLDSReadWidth; + + static constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth); + static constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth); + + static constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth); + static constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth); + + static constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth); + static constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * BLDSReadWidth); + + static constexpr index_t C_WMMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock / + (BlockSize / WaveSize) / + (MPerWmma * NPerWmma * KPerWmma); + + static constexpr auto Print() + { + printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerWmma: %d, %d, %d\n", + BlockSize, + WaveSize, + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + KPerWmma); + + printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " + "%d, %d\n C WMMA inst: %d\n" + "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: " + "%d, %d\n", + A_Buffer_Load_Inst_Num, + B_Buffer_Load_Inst_Num, + A_LDS_Write_Inst_Num, + B_LDS_Write_Inst_Num, + A_LDS_Read_Inst_Num, + B_LDS_Read_Inst_Num, + C_WMMA_Inst_Num, + A_LDS_Read_Width, + B_LDS_Read_Width, + ALDSWriteWidth, + BLDSWriteWidth, + ABufferLoadWidth, + BBufferLoadWidth); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp new file mode 100755 index 000000000000..d46c5b737d63 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -0,0 +1,381 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
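BlockwiseGemmWmmaops_pipeline_hotloop_inst above is pure bookkeeping: it derives per-block instruction counts from the tile shape and the load/store vector widths, which the scheduler then uses to interleave the hot loop. A worked example for one hypothetical configuration (the numbers below are illustrative and not taken from this patch), written as compile-time checks of the same formulas:

// Hypothetical WMMA tile configuration; all widths are in elements.
constexpr int BlockSize = 128, WaveSize = 32;
constexpr int MPerBlock = 128, NPerBlock = 64, KPerBlock = 64;
constexpr int MPerWmma = 16, NPerWmma = 16, KPerWmma = 16;
constexpr int MRepeat = 4, NRepeat = 2;
constexpr int LoadWidth = 8; // shared here by buffer loads, LDS writes, and LDS reads

constexpr int WaveNumM = MPerBlock / (MRepeat * MPerWmma); // 2 waves along M
constexpr int WaveNumN = NPerBlock / (NRepeat * NPerWmma); // 2 waves along N
static_assert(BlockSize == WaveNumM * WaveNumN * WaveSize);

// Per-block instruction counts, using the formulas from the struct above:
static_assert(MPerBlock * KPerBlock / (BlockSize * LoadWidth) == 8);              // A buffer loads
static_assert(NPerBlock * KPerBlock / (BlockSize * LoadWidth) == 4);              // B buffer loads
static_assert(WaveNumN * MPerBlock * KPerBlock / (BlockSize * LoadWidth) == 16);  // A LDS reads
static_assert(WaveNumM * NPerBlock * KPerBlock / (BlockSize * LoadWidth) == 8);   // B LDS reads
static_assert(MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize)
                  / (MPerWmma * NPerWmma * KPerWmma) == 32);                      // WMMA instructions

int main() { return 0; }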
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template +struct BlockwiseGemmWmmaops_pipeline_base +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I5 = Number<5>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = 32; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + +#if defined(__gfx12__) + static constexpr index_t A_KRow = 2; + static constexpr index_t B_KRow = 2; +#else + static constexpr index_t A_KRow = 1; + static constexpr index_t B_KRow = 1; +#endif + + static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5); + static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5); + + static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!"); + static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!"); + + static constexpr auto wmma_gemm = + WmmaGemm{}; + + static constexpr index_t KRepeat = KPerBlock / KPack; + + static constexpr auto WmmaK = Number{}; + + using HotLoopInstList = + ck::BlockwiseGemmWmmaops_pipeline_hotloop_inst; + + StaticBufferTupleOfVector + c_thread_buf_; + + struct Empty + { + __device__ Empty(){}; + template + __device__ void GlobalLoad(bool cond) + { + ignore = NBuffer; + ignore = cond; + } + }; + + template + struct BScale + { + __device__ BScale(GridDesc b_scale_grid_desc_, + ThreadCopy b_scale_thread_copy_, + GridBuffer b_scale_grid_buf_) + : b_scale_thread_copy(b_scale_thread_copy_), + b_scale_grid_desc(b_scale_grid_desc_), + b_scale_grid_buf(b_scale_grid_buf_){}; + + static constexpr index_t num_scale_k_block = BScaleThreadDesc{}.GetLength(Number<1>{}); + static constexpr index_t num_scale_krepeat = KRepeat / num_scale_k_block; + + static constexpr auto b_scale_thread_desc = BScaleThreadDesc{}; + + static constexpr auto b_scale_thread_copy_step = + make_tuple(make_multi_index(NWaves * NPerWmma, 0), + make_multi_index(-NPerBlock, 0), + make_multi_index(-NPerBlock, (KPerBlock + ScaleBlockK - 1) / ScaleBlockK)); + + template + __device__ void GlobalLoad(bool cond) + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, Number<0>{}), + b_scale_thread_bufs(Number{})); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + + if(cond) + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + } + } + + ThreadCopy b_scale_thread_copy; + GridDesc b_scale_grid_desc; + GridBuffer b_scale_grid_buf; + StaticallyIndexedArray{}> b_scale_thread_bufs; + }; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + 
make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto wmma_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); + +#if defined(__gfx12__) + const auto wmma_krow = wmma_gemm.GetSubGroupId(); +#else + const auto wmma_krow = 0; +#endif + + // |KRepeat |MRepeat|MWave |KRow |MLane |KPack + return make_tuple(0, 0, waveId_m, wmma_krow, wmma_a_idx, 0); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto wmma_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); + +#if defined(__gfx12__) + const auto wmma_krow = wmma_gemm.GetSubGroupId(); +#else + const auto wmma_krow = 0; +#endif + + // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack + return make_tuple(0, 0, waveId_n, wmma_krow, wmma_b_idx, 0); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); + + constexpr auto mrepeat_mwave_mperwmma_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWmma))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperwmma_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWmma))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperwmma_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperwmma_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); + + /** + * @brief Constructor for BlockwiseGemmWmmaops_pipeline_base. + * + * This constructor initializes the thread copy objects for matrices A and B. + * It also performs several compile-time checks to ensure the correctness of the + * matrix tile descriptors. + * + * @param a_origin The origin data index for matrix A. + * @param b_origin The origin data index for matrix B. + * + * @note The constructor includes static assertions to ensure that: + * - The matrix tile descriptors for A and B are known at compile-time. + * - The number of threads in the thread block matches the product of MWaves, NWaves, and + * WaveSize. + * - The dimensions of the block are divisible by the product of the corresponding WMMA and + * repeat dimensions. + */ + __host__ __device__ + BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), + Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AWmmaTileDesc::IsKnownAtCompileTime() && + BWmmaTileDesc::IsKnownAtCompileTime(), + "wrong! 
Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize"); + + static_assert(MPerBlock % (MPerWmma * MRepeat) == 0 && + NPerBlock % (NPerWmma * NRepeat) == 0, + "wrong!"); + } + + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; + return make_naive_tensor_descriptor( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, MAccVgprs), + make_tuple(Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + AccStride)); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + // Describe how data allocated in thread copy src buffer + // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma + static constexpr AWmmaTileDesc a_block_desc_k0_m0_m1_m2_k1; + static constexpr BWmmaTileDesc b_block_desc_k0_n0_n1_n2_k1; + + protected: + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + Number{}, + I1, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I0, + I0, + I1)); + + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + Number{}, + I1, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I0, + I0, + I1)); + + // C[M, N, NumRegWmma] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); + + using AThreadCopy = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + using BThreadCopy = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp new file mode 100755 index 000000000000..f25648efa64a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -0,0 +1,705 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
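// Plain-arithmetic illustration of CalculateCThreadOriginDataIndex() above: the
// (repeat, wave, in-WMMA row) triple is folded back into a block-local row as
// m0 * MWaves * MPerWmma + wave_m * MPerWmma + blk_m, and columns work the same way.
// The sizes below are assumed example values.
namespace c_origin_example {
constexpr int MRepeat = 4, MWaves = 2, MPerWmma = 16;

constexpr int c_thread_m(int m0, int wave_m, int blk_m)
{
    return m0 * (MWaves * MPerWmma) + wave_m * MPerWmma + blk_m;
}

static_assert(c_thread_m(0, 0, 0) == 0, "first repeat, first wave, WMMA row 0 maps to row 0");
static_assert(c_thread_m(1, 1, 3) == 51, "second repeat, second M-wave, WMMA row 3");
static_assert(c_thread_m(MRepeat - 1, MWaves - 1, MPerWmma - 1) ==
                  MRepeat * MWaves * MPerWmma - 1,
              "the last repeat/wave/row combination lands on the last row of the tile");
} // namespace c_origin_example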
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 1 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmWmmaops_pipeline_v1 +{ +}; + +template +struct BlockwiseGemmWmmaops_pipeline_v1 + : BlockwiseGemmWmmaops_pipeline_base + +{ + using Base = BlockwiseGemmWmmaops_pipeline_base; + using Base::I0; + + using Base::A_K1; + using Base::A_KRow; + using Base::B_K1; + using Base::B_KRow; + using Base::KRepeat; + using Base::WmmaK; + + using Base::wmma_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base:: + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs; + using Base::GetCThreadBuffer; + using Base:: + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs; + + using Base::a_block_desc_k0_m0_m1_m2_k1; + using Base::b_block_desc_k0_n0_n1_n2_k1; + + using typename Base::Empty; + + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; } + + static TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + // BScaleThreadCopy + BScaleStruct& b_scale_struct, + index_t num_loop, + index_t num_loop_per_scale) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + auto blockwise_gemm_func = [&]() { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0), + a_thread_buf); + }); + if constexpr(ck::is_same::value == true) + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_scale_struct.b_scale_thread_bufs( + I0)[Number{}], + 
b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + } + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, m0, k0, I0, I0, Number{}))>{}]; + }); + static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, n0, k0, I0, I0, Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }; + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + blockwise_gemm_func(); + + block_sync_lds(); + b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + blockwise_gemm_func(); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +template +struct BlockwiseGemmWmmaops_pipeline_v1 + : BlockwiseGemmWmmaops_pipeline_base + +{ + using Base = BlockwiseGemmWmmaops_pipeline_base; + using Base::I0; + using Base::I1; + + using Base::A_K1; + using Base::A_KRow; + using Base::B_K1; + using Base::B_KRow; + using Base::KRepeat; + using Base::WmmaK; + + using Base::wmma_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base:: + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs; + using Base::GetCThreadBuffer; + using Base:: + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs; + + using Base::a_block_desc_k0_m0_m1_m2_k1; + using Base::b_block_desc_k0_n0_n1_n2_k1; + + using typename Base::Empty; + + static constexpr index_t NumKClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS; + static constexpr index_t KRepeatPerCluster = math::max(KRepeat / NumKClusters, 1); + + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; } + + static TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + 
CThreadBuffer& c_thread_buf, + // BScaleThreadCopy + BScaleStruct& b_scale_struct, + index_t num_loop, + index_t num_loop_per_scale) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + auto blockwise_gemm_func = [&]() { + static_for<0, KRepeat, KRepeatPerCluster>{}([&](auto k0_offset) { + static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{}, + m0, + I0, + I0, + I0, + I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0_inner, I0, I0, I0), + a_thread_buf); + }); + if constexpr(ck::is_same::value == true) + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, + n0, + I0, + I0, + I0, + I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0_inner, I0, I0, I0), + b_thread_buf); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, + n0, + I0, + I0, + I0, + I0), + b_block_buf, + b_scale_struct.b_scale_thread_bufs(I0)[Number< + n0 * BScaleStruct::num_scale_k_block + + (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}], + b_thread_desc_, + make_tuple(I0, n0, k0_inner, I0, I0, I0), + b_thread_buf); + }); + } + }); + + __builtin_amdgcn_sched_barrier(0); + // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, + // but except the first, as we can shorten non-MAC cluster a bit and there's no + // observable negative impact. The desired effect is waves in a workgroup + // executing MAC in sync. This avoids some out-of-sync waves hijacking MAC + // resource from other workgroups and reducing the chance of latency hiding by + // waiting for the rest of the workgroup at the eventual sync point. 
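            // (For reference: __builtin_amdgcn_sched_barrier(0) only fences the compiler's
            // instruction scheduler, i.e. nothing may be reordered across it, while
            // __builtin_amdgcn_s_barrier() below is the hardware workgroup barrier.)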
+ if constexpr(k0_offset != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0_inner, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0_inner, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard. + // B) reduce VMEM FIFO congestion by applying small delays to + // different wavefronts. + // It is performed near the end of MAC cluster to minimize lgkmcnt + // penalty + if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 && + n0 == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + }; + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + blockwise_gemm_func(); + + b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + blockwise_gemm_func(); + } + } + + protected: + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + Number{}, + I1, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I0, + I0, + I1)); + + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + Number{}, + I1, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I0, + I0, + I1)); + + using AThreadCopy = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + using BThreadCopy = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp new file mode 100755 index 000000000000..8fed23d151e6 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp @@ -0,0 +1,501 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmWmmaops_pipeline_v3 +{ +}; + +template +struct BlockwiseGemmWmmaops_pipeline_v3 + : BlockwiseGemmWmmaops_pipeline_base +{ + using Base = BlockwiseGemmWmmaops_pipeline_base; + using Base::I0; + + using Base::A_K1; + using Base::A_KRow; + using Base::B_K1; + using Base::B_KRow; + using Base::KRepeat; + using Base::WmmaK; + + using Base::wmma_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base:: + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs; + using Base::GetCThreadBuffer; + using Base:: + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs; + + using Base::a_block_desc_k0_m0_m1_m2_k1; + using Base::b_block_desc_k0_n0_n1_n2_k1; + + using typename Base::Empty; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // TODO: Calculation of the number of instructions may require changes for WMMA + /* + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_wmma_inst = HotLoopInstList::C_WMMA_Inst_Num; + + constexpr auto wmma_cycle = NPerWmma == 16 ? 16 : 32; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 
8 : 4; + constexpr auto ds_read_a_wmma_rate = + (wmma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_wmma_rate = + (wmma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_wmma = + (num_ds_read_inst_a + ds_read_a_wmma_rate - 1) / ds_read_a_wmma_rate; + constexpr auto num_dsread_b_wmma = + (num_ds_read_inst_b + ds_read_b_wmma_rate - 1) / ds_read_b_wmma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_wmma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / sizeof(BDataType) + // ? sizeof(ComputeDataType) / sizeof(ADataType) + // : sizeof(ComputeDataType) / sizeof(BDataType); + constexpr auto num_wmma_stage1 = num_wmma_inst - (num_dsread_a_wmma + num_dsread_b_wmma); + constexpr auto num_wmma_per_issue = + num_wmma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_wmma_per_issue - num_dswrite_per_issue_a, 0); // WMMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_wmma_per_issue - num_dswrite_per_issue_b, 0); // WMMA + }); + + // stage 2 + static_for<0, num_dsread_a_wmma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_wmma_rate) >= + ds_read_a_wmma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_wmma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_wmma - 1) * + ds_read_a_wmma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); + + static_for<0, num_dsread_b_wmma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_wmma_rate) >= + ds_read_b_wmma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_wmma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_wmma - 1) * + ds_read_b_wmma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA + }); + */ + } + + template + __device__ inline void LocalLoad(ABlockBuffer& a_block_buf, + AThreadBuffer& a_thread_buf, + BBlockBuffer& b_block_buf, + BThreadBuffer& b_thread_buf, + BScaleStruct& b_scale_struct) const + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0), + a_thread_buf); + }); + + if constexpr(ck::is_same_v) + { + static_for<0, 
NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_scale_struct.b_scale_thread_bufs( + I0)[Number{}], + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + } + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + // BScaleThreadCopy + BScaleStruct& b_scale_struct, + index_t num_loop, + index_t num_loop_per_scale) const + { + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefetch 1 + block_sync_lds(); + + LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + 
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 1)); + } + // tail + if constexpr(TailNum == TailNumber::Full) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, m0, k0, I0, I0, Number{}))>{}]; + }); + static_for<0, KPack / B_KRow, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, n0, k0, I0, I0, Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp new file mode 100755 index 000000000000..438d7d8ac3c1 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp @@ -0,0 +1,994 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
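// Schematic of the two-stage software pipeline implemented by the v3 Run() above,
// with the per-tile operations reduced to callables. The callable names are
// illustrative stand-ins, not functions from this file, and LDS synchronization,
// slice-window moves and the B-scale loads are omitted.
template <typename GlobalLoad, typename LdsStore, typename LdsLoad, typename Compute>
void pipeline_v3_sketch(int num_loop,
                        GlobalLoad global_load, // buffer_load: global memory -> staging regs
                        LdsStore lds_store,     // staging regs -> LDS
                        LdsLoad lds_load,       // LDS -> per-thread compute regs
                        Compute wmma_compute)   // WMMA math on compute regs
{
    global_load(); // global prefetch 1 (tile 0)
    lds_store();   // local prefill 1   (tile 0 -> LDS)
    global_load(); // global prefetch 2 (tile 1)
    lds_load();    // local prefetch 1  (tile 0 -> registers)

    for(int i = 0; i < num_loop - 1; ++i)
    {
        lds_store();    // park the previously fetched tile in LDS
        global_load();  // start fetching the tile after next
        wmma_compute(); // math on the tile already staged in registers
        lds_load();     // stage the next tile's fragments for the following iteration
    }
    wmma_compute();     // tail: last tile
}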
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +// Double LDS buffer +// Prefetech 2 stage +// Local prefetch 1 stage + +namespace ck { + +template +struct BlockwiseGemmXdlops_pipeline_hotloop_inst +{ + static constexpr index_t WaveSize = 64; + static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL); + + static constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth); + static constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth); + + static constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth); + static constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth); + + static constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth); + static constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth); + + static constexpr index_t C_MFMA_Inst_Num = + MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + static constexpr auto Print() + { + printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n", + BlockSize, + WaveSize, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + KPerXDL); + + printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " + "%d, %d\n C MFMA inst: %d\n", + A_Buffer_Load_Inst_Num, + B_Buffer_Load_Inst_Num, + A_LDS_Write_Inst_Num, + B_LDS_Write_Inst_Num, + A_LDS_Read_Inst_Num, + B_LDS_Read_Inst_Num, + C_MFMA_Inst_Num); + } +}; + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename ATileDesc, + typename BTileDesc, + typename AMmaTileDesc, + typename BMmaTileDesc, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t MPerXDL, + index_t NPerXDL, + index_t MRepeat, + index_t NRepeat, + index_t KPack, + bool TransposeC = false, + index_t AMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops, + index_t BMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops> +struct BlockwiseGemmXdlops_pipeline_v4 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); + static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = + XdlopsGemm{}; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + static constexpr index_t KRepeat = KPerThread / KPack; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + using HotLoopInstList = BlockwiseGemmXdlops_pipeline_hotloop_inst; + + static_assert(KPerThread % KPack == 0, + "Wrong KPack setting; try increasing KPerThread or decreasing KPack"); + + StaticBufferTupleOfVector + 
c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i); + + return make_tuple( + m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]); + } + + using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); + + __host__ __device__ + BlockwiseGemmXdlops_pipeline_v4(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), + Tuple4 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), + "wrong! 
Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + + // HotLoopInstList::Print(); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, N, M0, M1, M2)); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / 
(MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + __device__ static constexpr auto HotLoopScheduler() + { + // schedule + constexpr auto num_ds_read_inst = + HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num; + constexpr auto num_ds_write_inst = + HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num; + ; + constexpr auto num_buffer_load_inst = + HotLoopInstList::A_Buffer_Load_Inst_Num + HotLoopInstList::B_Buffer_Load_Inst_Num; + ; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto num_issue = num_buffer_load_inst; + + static_for<0, num_issue, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier( + 0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier( + 0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA + }); + } + + template + __device__ static constexpr auto TailScheduler() + { + } + + template <> + __device__ constexpr auto TailScheduler<1>() + { + // schedule + constexpr auto num_ds_read_inst = + HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num; + constexpr auto num_ds_write_inst = + HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num; + ; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto num_issue = num_ds_write_inst; + + static_for<0, num_issue, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier( + 0x100, num_ds_read_inst / num_ds_write_inst - 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_inst / num_ds_write_inst - 3, 0); // MFMA + }); + } + + template <> + __device__ constexpr auto TailScheduler<2>() + { + // schedule + constexpr auto 
num_ds_read_inst = + HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num; + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto num_issue = num_ds_read_inst; + + static_for<0, num_issue, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_inst / num_ds_read_inst, 0); // MFMA + }); + } + + static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; + static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs; + // Inst List: + // ds_read_b128: 16 + // ds_write_b128: 8 + // buffer_load_dwordx4: 16 + // v_mfma: 0 + // ------------------------------------------------------------------------------------------- + + // Global prefetch 1th, Fill Ping LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0)); + + // Local prefetch 1th, Fill Ping Reg + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(I0)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(I0), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(I0)); + }); + }); + }); + + // Global prefetch 2th, Fill Pong LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1)); + + // Global prefetch 3rd + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + // This hot loop has two legacy loopover, to implement the double local buffer strategy + do + { + // ------------------------------------------------------------------------------------------- + 
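                // (Each trip through this do-while covers two main-loop iterations, the ping
                // half below and the pong half after it, which is why the induction variable i
                // advances by 2 at the bottom of the loop.)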
using PingP1 = Number<0>; + using PongP1 = Number<1>; + // MFMA: Ping Reg + // DS_WRITE: To Ping LDS + // DS_READ: Pong LDS to Pong Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); + }); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{})); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{})); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + // ------------------------------------------------------------------------------------------- + using PingP2 = Number<1>; + using PongP2 = Number<0>; + // MFMA: Pong Reg + // DS_WRITE: To Pong LDS + // DS_READ: Ping LDS to Ping Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP2{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP2{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP2{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP2{})); + }); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP2{})); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP2{})); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + 
xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 2; + } while(i < (num_loop - 3)); + } + + // tail + if constexpr(TailNum == 3) + { + using PingP1 = Number<0>; + using PongP1 = Number<1>; + // MFMA: Ping Reg + // DS_WRITE: To Ping LDS + // DS_READ: Pong LDS to Pong Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); + }); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{})); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{})); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + TailScheduler<1>(); + __builtin_amdgcn_sched_barrier(0); + + // ------------------------------------------------------------------------------------------- + using PingP2 = Number<1>; + using PongP2 = Number<0>; + // MFMA: Pong Reg + // DS_WRITE: To Pong LDS + // DS_READ: Ping LDS to Ping Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP2{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP2{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP2{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP2{})); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + TailScheduler<2>(); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 
1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PongP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PongP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + // 64 v_mfma + __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == 2) + { + using PingP1 = Number<0>; + using PongP1 = Number<1>; + // MFMA: Ping Reg + // DS_WRITE: To Ping LDS + // DS_READ: Pong LDS to Pong Reg + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(PongP1{}), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(PongP1{})); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(PongP1{}), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(PongP1{})); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP1{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP1{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + TailScheduler<2>(); + __builtin_amdgcn_sched_barrier(0); + + // ------------------------------------------------------------------------------------------- + using PingP2 = Number<1>; + // MFMA: Pong Reg + // DS_WRITE: To Pong LDS + // DS_READ: Ping LDS to Ping Reg + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[PingP2{}][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[PingP2{}][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + // 64 v_mfma + __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA + __builtin_amdgcn_sched_barrier(0); + } + } + + protected: + // M1, N1 as double buffer index + // Read buffer + Compute buffer + // A[M0, M1, M2, KPack] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // B[N0, N1, N2, KPack] + static constexpr auto 
b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp new file mode 100755 index 000000000000..5997b4779079 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp" + +namespace ck { + +template +constexpr auto BlockGemmABScalePipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + return BlockwiseGemmXdlops_pipeline_v1_ab_scale{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + return BlockwiseGemmXdlops_pipeline_v2_ab_scale{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return BlockwiseGemmXdlops_pipeline_v3_ab_scale{}; + } + else + { + std::cerr << "BlockGemmPipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp new file mode 100755 index 000000000000..7a565fbaa752 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp @@ -0,0 +1,547 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
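// Editor's note (illustrative sketch, not part of this patch): the pipeline structs in this
// header consume a preshuffled, quantized B operand straight from global memory into VGPRs and
// dequantize it register-to-register before every MFMA, while A still flows through LDS.
// Functionally each block tile computes the reference loop below; the scalar types and the
// `dequant` helper are assumptions used only to state the semantics.
//
//   template <typename BStorage, typename Compute>
//   void block_gemm_bdequant_reference(const Compute* a,   // [M x K], row-major
//                                      const BStorage* b,  // [N x K], quantized
//                                      float* c,           // [M x N] accumulators
//                                      int M, int N, int K,
//                                      Compute (*dequant)(BStorage))
//   {
//       for(int m = 0; m < M; ++m)
//           for(int n = 0; n < N; ++n)
//               for(int k = 0; k < K; ++k)
//                   c[m * N + n] += float(a[m * K + k]) * float(dequant(b[n * K + k]));
//   }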
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto b_thread_dequant_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + StaticallyIndexedArray{}> b_thread_dequant_bufs; + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I0)); + + // Initialize C + c_thread_buf.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + 
b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[mfma_reg_buf] + [Number{}]; + }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(local_read_buf)); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I1)); + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + 
a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using BThreadDequantCopy = ThreadwiseTensorSliceTransfer_StaticToStatic< + BDataType, + ComputeDataType, + decltype(b_block_desc_n0_n1_k0_k1), + decltype(b_block_desc_n0_n1_k0_k1), + tensor_operation::element_wise::PassThrough, + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + KPack>; + + const PassThrough b_element_op{}; + BThreadDequantCopy b_thread_dequant_copy_{b_element_op}; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp new file mode 100755 index 000000000000..8b227a8aa1b6 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp @@ -0,0 +1,932 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
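// Editor's note (illustrative, not part of this patch): the v3 pipeline below overlaps the next
// tile's global loads and LDS writes with the current tile's MFMAs via two alternating buffers
// ("ping-pong"), issuing its scheduler hints once per m0 stage rather than once per loop. A
// host-side analogue of the two-stage schedule, with hypothetical load_tile()/compute_tile()
// helpers, is:
//
//   load_tile(0, buf[0]);                    // prologue: prefetch the first tile
//   for(int i = 0; i + 1 < num_loop; ++i)
//   {
//       load_tile(i + 1, buf[(i + 1) % 2]);  // start fetching the next tile ...
//       compute_tile(buf[i % 2]);            // ... while consuming the current one
//   }
//   compute_tile(buf[(num_loop - 1) % 2]);   // tail: drain the last prefetched tile
//
// BlockLoopTailNum below records whether num_loop is even or odd, i.e. which buffer the drain
// step ends on, so the epilogue can be specialized accordingly.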
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::KRepeat; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + using Base::MWaves; + + static constexpr auto xdlops_gemm = + XdlopsGemm{}; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + template + __device__ static constexpr auto HotLoopScheduler(Stage stage) + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto staged_num_ds_read_inst_a = + ck::math::integer_divide_ceil(num_ds_read_inst_a, MRepeat); + constexpr auto staged_num_mfma = ck::math::integer_divide_ceil(num_mfma, MRepeat); + + constexpr auto staged_num_mfma_per_ds_read_a = + ck::math::integer_divide_ceil(staged_num_mfma, staged_num_ds_read_inst_a); + + if constexpr(stage.value == 0) + { + constexpr auto staged_num_buffer_load_b_per_ds_read_a = + ck::math::integer_divide_ceil(num_buffer_load_inst_b, staged_num_ds_read_inst_a); + constexpr auto staged_num_mfma_per_buffer_load_b = + ck::math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_b); + // B global + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + + static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) { + ignore = ibuf_inst; + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(stage.value == 1) + { + constexpr auto staged_num_mfma_per_ds_write_a = + math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); + + constexpr auto stage_more_mfma = + staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; + + // A local write + static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { + if constexpr(i_inst.value < stage_more_mfma) + { + if(i_inst.value < staged_num_ds_read_inst_a) + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + } + } + else + { + if(i_inst.value < staged_num_ds_read_inst_a) + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + } + } + }); + + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(stage.value == 2) + { + constexpr auto staged_num_mfma_per_buffer_load_a = + math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_a); + + 
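// Editor's note (illustrative, not part of this patch): assuming the usual definition
// ck::math::integer_divide_ceil(a, b) == (a + b - 1) / b, the stage-2 ladder below budgets
// roughly staged_num_mfma / num_buffer_load_inst_a MFMAs per A buffer_load. Worked example
// with staged_num_mfma = 16 and num_buffer_load_inst_a = 3:
//   staged_num_mfma_per_buffer_load_a = ceil(16 / 3)      = 6
//   stage_more_mfma                   = 16 - (6 - 1) * 3  = 1
// so one load gets 6 MFMAs scheduled ahead of it and the other two get 5 each
// (1*6 + 2*5 = 16), which is exactly the remainder redistribution the if/else branches
// below implement.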
constexpr auto stage_more_mfma = + staged_num_mfma - (staged_num_mfma_per_buffer_load_a - 1) * num_buffer_load_inst_a; + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i_inst) { + if constexpr(i_inst.value < stage_more_mfma) + { + if(i_inst.value < staged_num_ds_read_inst_a) + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + } + else + { + if(i_inst.value < staged_num_ds_read_inst_a) + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_a - 2, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + } + }); + + __builtin_amdgcn_sched_barrier(0); + } + else + { + // A local Read + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); + } + } + + template + __device__ static constexpr auto EpilogueScheduler_1(Stage stage) + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; + constexpr auto staged_num_mfma = num_mfma / MRepeat; + + constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; + + if constexpr(stage.value == 0) + { + constexpr auto staged_num_buffer_load_b_per_ds_read_a = + num_buffer_load_inst_b / staged_num_ds_read_inst_a; + constexpr auto staged_num_mfma_per_buffer_load_b = + staged_num_mfma / num_buffer_load_inst_b; + // B global + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + + static_for<0, staged_num_buffer_load_b_per_ds_read_a, 1>{}([&](auto ibuf_inst) { + ignore = ibuf_inst; + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(stage.value == 1) + { +#if 0 + constexpr auto staged_num_ds_write_a_per_ds_read_a = + num_ds_write_inst_a / staged_num_ds_read_inst_a; + constexpr auto staged_num_mfma_per_ds_write_a = staged_num_mfma / num_ds_write_inst_a; + // A local write + 
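// Editor's note (illustrative, not part of this patch): this #if 0 branch is a simpler
// even-split schedule kept for reference; the active #elif 1 branch below redistributes the
// remainder MFMAs the same way HotLoopScheduler does. In both variants
// __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) is only a scheduling hint: it asks
// the backend to place a group of `size` instructions of class `mask` at this point
// (0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write), ordered against the
// other groups that share the same sync_id, and has no effect on what is computed.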
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + + static_for<0, staged_num_ds_write_a_per_ds_read_a, 1>{}([&](auto idswrite_inst) { + ignore = idswrite_inst; + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + }); + + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_ds_write_a_per_ds_read_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); +#elif 1 + constexpr auto staged_num_mfma_per_ds_write_a = + math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); + + constexpr auto stage_more_mfma = + staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; + + // A local write + static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { + if constexpr(i_inst.value < stage_more_mfma) + { + if(i_inst.value < staged_num_ds_read_inst_a) + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + } + } + else + { + if(i_inst.value < staged_num_ds_read_inst_a) + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + } + } + }); +#endif + __builtin_amdgcn_sched_barrier(0); + } + else + { + // A local Read + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); + } + } + + __device__ static constexpr auto EpilogueScheduler_2() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + + constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; + constexpr auto staged_num_mfma = num_mfma / MRepeat; + + constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; + + // A local Read + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + __builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) 
const + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + auto b_thread_dequant_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_dequant_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + // Global prefetch A1 B1 + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + + // // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(I0, I0, I0, k0, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(I0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I0)); + + // Initialize C + c_thread_buf.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + if constexpr(m0.value == 0) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + } + else if constexpr(m0.value == 1) + { + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + } + else if constexpr(m0.value == 2) + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + } + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value == MRepeat - 1) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + I0), + a_thread_buf); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + 
b_block_origin_idx, + b_thread_bufs(local_read_buf), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(local_read_buf)); + } + else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 1 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + I0), + a_thread_buf); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(mfma_reg_buf), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(mfma_reg_buf)); + } + + HotLoopScheduler(m0); + }); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + if constexpr(m0.value == 0) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + } + else if constexpr(m0.value == MRepeat - 1) + { + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + } + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value == MRepeat - 1) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), + a_thread_buf); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I1)); + } + else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 1) % MRepeat>{}, I0, I0, k0, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), + a_thread_buf); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I0)); + } + + EpilogueScheduler_1(m0); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template 
AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value != (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number{}, I0, I0, k0, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0), + a_thread_buf); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I1)); + + EpilogueScheduler_2(); + } + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value != (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number{}, I0, I0, k0, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), + a_thread_buf); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I0)); + + EpilogueScheduler_2(); + } + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + // Reduce the vgpr usage here. 
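// Editor's note (illustrative, not part of this patch): note the VGPR saving relative to the
// v1 pipeline earlier in this series: a_thread_desc_ below keeps only two MRepeat slots (the
// leading extent I2) instead of all MRepeat of them. The Run() loops above index those slots
// with (m0 + 1) % 2 (shifted by HotloopLocalBufSwitch when MRepeat is odd), so the A fragment
// for the next m0 stage is fetched from LDS while the current one feeds the MFMAs: a register
// double buffer that shrinks the A register footprint at the cost of interleaving the LDS
// reads into every m0 stage (cf. the "Reduce the vgpr usage here" note).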
+ static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(I2, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using BThreadDequantCopy = ThreadwiseTensorSliceTransfer_StaticToStatic< + BDataType, + ComputeDataType, + decltype(b_block_desc_n0_n1_k0_k1), + decltype(b_block_desc_n0_n1_k0_k1), + tensor_operation::element_wise::PassThrough, + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + KPack>; + + const PassThrough b_element_op{}; + BThreadDequantCopy b_thread_dequant_copy_{b_element_op}; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp new file mode 100755 index 000000000000..29750b8baa87 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp @@ -0,0 +1,621 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< + BlockGemmPipelineScheduler::Intrawave, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack> : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + using Base::c_thread_desc_; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + 
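// Editor's note (illustrative sketch, not part of this patch): the "gufusion" pipelines carry
// two B operands and two accumulators through a single hot loop, so every A fragment read from
// LDS is reused by two back-to-back MFMAs (c_thread_buf and c_thread_buf_up). Presumably this
// fuses two projections that share one activation input (e.g. an MLP gate/up pair), though
// that reading is an assumption. Per output element the fused k-loop reduces to the reference
// below; the arrays, scalar types and the `dequant` helper are hypothetical.
//
//   float acc = 0.f, acc_up = 0.f;
//   for(int k = 0; k < K; ++k)
//   {
//       const float a = float(A[m][k]);          // loaded once from LDS ...
//       acc    += a * float(dequant(B[n][k]));   // ... and consumed by both GEMMs
//       acc_up += a * float(dequant(B_up[n][k]));
//   }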
+ template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + index_t num_loop) const + + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto b_thread_dequant_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + StaticallyIndexedArray{}> b_thread_dequant_bufs; + StaticallyIndexedArray{}> + b_thread_dequant_bufs_up; + + // Global 
prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I0)); + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs_up(I0)); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up + [mfma_reg_buf][Number{}]; + }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + 
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(local_read_buf)); + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs_up(local_read_buf)); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); + }); + }); + // B VGPR->VGPR dequant + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs(I1)); + + b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1), + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_dequant_bufs_up(I1)); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + 
b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_dequant_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_dequant_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using BThreadDequantCopy = ThreadwiseTensorSliceTransfer_StaticToStatic< + BDataType, + ComputeDataType, + decltype(b_block_desc_n0_n1_k0_k1), + decltype(b_block_desc_n0_n1_k0_k1), + tensor_operation::element_wise::PassThrough, + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + KPack>; + + const PassThrough b_element_op{}; + BThreadDequantCopy b_thread_dequant_copy_{b_element_op}; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp new file mode 100755 index 000000000000..fe89e700c415 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp @@ -0,0 +1,582 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
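// Editor's note (illustrative, not part of this patch): the header that follows is the
// non-dequant counterpart of the gufusion pipeline above. B is assumed to arrive already in
// the compute type, so b_thread_bufs / b_thread_bufs_up feed the MFMAs directly and the
// b_thread_dequant_* buffers and copies are gone. The other visible difference is the KGroup
// split: A's K dimension is unmerged into (KRepeat * KGroup, K1, KPack / KGroup), so each
// a_thread_copy_.Run below moves one KGroup slice per call instead of a full KPack vector.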
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + using Base::c_thread_desc_; + using Base::MWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = + HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves * 2; + constexpr auto mfma_interleave = MPerXDL == 32 ? 
1 : 2; + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + if constexpr(MPerBlock >= 128 && NPerBlock >= 64) + { + __builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0); + } + else + { + __builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0); + } + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + // if constexpr(i.value < num_buffer_load_inst_a) { + // __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + // __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + // } + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}( + [&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 2 : 1, 0); // DS read + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + index_t num_loop) const + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + // 
Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[mfma_reg_buf] + [Number{}]; + }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + 
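+                        // Second MFMA of the fused pair: same A fragment against the
+                        // up-projection B fragment, accumulated separately into c_thread_buf_up.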
xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp new file mode 100755 index 
000000000000..c76be74e525a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp @@ -0,0 +1,952 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + using Base::MWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? 
HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2; + + static_assert(num_buffer_load_inst_a == num_ds_write_inst_a); + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * 2; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); + + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + + constexpr auto num_total_stages = MRepeat; + + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto num_ds_read_a_prefetch_stages = 2; + + constexpr auto buffer_load_perstage_more = math::integer_divide_ceil( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = math::integer_divide_floor( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + + constexpr auto buffer_load_stages_more = + (num_buffer_load_inst_a + num_buffer_load_inst_b) - + math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b), + (num_total_stages - 2)) * + ((num_total_stages - 2)); + + constexpr auto buffer_load_b_stages = + buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b + ? num_buffer_load_inst_b / buffer_load_perstage_more + : (buffer_load_stages_more + + (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) / + buffer_load_perstage_less); + + constexpr auto buffer_load_a_stages = + num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages; + + constexpr auto buffer_load_issue_point_b = 0; + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less; + constexpr auto ds_write_issue_point = 0; + constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 
1 : 0; + + // B global read + static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_b)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_b))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // A global read + A local write + static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + ds_write_issue_point)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + ds_write_issue_point))) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + } + + template + __device__ static constexpr auto EpilogueScheduler_1(Stage stage) + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_buffer_load_inst_b = + MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num * 2; + + constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2; + + constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; + constexpr auto staged_num_mfma = num_mfma / MRepeat; + + constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; + + if constexpr(stage.value == 0) + { + constexpr auto staged_num_buffer_load_b_per_ds_read_a = + num_buffer_load_inst_b / staged_num_ds_read_inst_a; + constexpr auto staged_num_mfma_per_buffer_load_b = + staged_num_mfma / num_buffer_load_inst_b; + // B global + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + + static_for<0, staged_num_buffer_load_b_per_ds_read_a, 1>{}([&](auto ibuf_inst) { + ignore = ibuf_inst; + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + 
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(stage.value == 1) + { + constexpr auto staged_num_mfma_per_ds_write_a = + math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a); + + constexpr auto stage_more_mfma = + staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a; + + // A local write + static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) { + if constexpr(i_inst.value < stage_more_mfma) + { + if(i_inst.value < staged_num_ds_read_inst_a) + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + } + } + else + { + if(i_inst.value < staged_num_ds_read_inst_a) + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write + } + } + }); + __builtin_amdgcn_sched_barrier(0); + } + else + { + // A local Read + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + __builtin_amdgcn_sched_group_barrier( + 0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); + } + } + + __device__ static constexpr auto EpilogueScheduler_2() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + + constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2; + + constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat; + constexpr auto staged_num_mfma = num_mfma / MRepeat; + + constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a; + + // A local Read + static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) { + ignore = i_inst; + __builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + }); + + __builtin_amdgcn_sched_barrier(0); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + index_t num_loop) const + { + ignore = b_block_buf; + 
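+        // Run() structure (summary of the control flow below):
+        //   prologue: global prefetch of A and of both B streams (gate + up),
+        //             local prefill of A into LDS, local prefetch of the first
+        //             two MRepeat stages into registers;
+        //   hot loop: LoopFunc(I0, I1); LoopFunc(I1, I0); consumes two K blocks
+        //             per pass while ping-ponging the LDS/register buffers;
+        //   tail:     TailNumber::Even finishes the last two K blocks, Odd the last one.
+        // Illustrative counts (not a fixed config): num_loop = 8 gives 3 hot-loop
+        // passes (6 K blocks) plus the Even tail (2); num_loop = 7 gives 3 passes
+        // plus the Odd tail (1).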
__builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + // Global prefetch A1 B1 + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + + // // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, 2, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value == MRepeat - 2) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 
1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + HotLoopScheduler(); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + if constexpr(m0.value == (MRepeat - 2)) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == MRepeat - 1) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + 
else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + }); + + HotLoopScheduler(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value < (MRepeat - 2)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + + HotLoopScheduler(); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value < (MRepeat - 2)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + // Reduce the vgpr usage here. 
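+    // VGPR saving noted above: the leading I2 dimension of the a_thread_desc_
+    // declared below keeps only two MRepeat stages of A fragments resident.
+    // The hot loop refills slot (m0 + 2) % 2 from LDS two stages ahead of the
+    // MFMA that consumes it, with HotloopLocalBufSwitch flipping the parity per
+    // LDS buffer when MRepeat is odd, so the full MRepeat x KRepeat A tile never
+    // has to sit in registers.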
+ static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(I2, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp new file mode 100755 index 000000000000..b3b3d312c7ef --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp @@ -0,0 +1,1200 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< + BlockGemmPipelineScheduler::Intrawave, + ThreadBlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack> : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::A_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::APackedSize; + using Base::BMmaKStride; + using Base::BPackedSize; + using Base::KThreadChunk; + + using Base::KXdlPack; + using Base::MXdlPack; + using Base::NXdlPack; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t LocalPrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t 
GlobalBufferNum = 1; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1; + + static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack; + static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack * 2; + static constexpr auto async_vmcnt = num_buffer_load_a_scale + num_buffer_load_b_scale + + HotLoopInstList::B_Buffer_Load_Inst_Num * 2; + static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2; + constexpr auto num_buffer_load_stage1 = + num_buffer_load_inst_b + num_buffer_load_a_scale + num_buffer_load_b_scale; + + constexpr auto num_buffer_load_stage2 = num_buffer_load_inst_a; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize * 2; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 
8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 8, 2 * ds_read_a_issue_cycle); + + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + + constexpr auto num_total_stages = MRepeat; + + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto num_ds_read_a_prefetch_stages = 2; + + constexpr auto buffer_load_perstage_more = + math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = + math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_stage2 = + math::integer_divide_floor((num_buffer_load_stage2), 2); + + constexpr auto buffer_load_stages_more = + num_buffer_load_stage1 - + math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) * + ((num_total_stages - 2)); + + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less; + constexpr auto buffer_load_issue_point_interval_stage2 = + num_mfma_perstage / buffer_load_perstage_stage2; + + // Stage 1 + // global read more + static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + if constexpr(imfma % buffer_load_issue_point_interval_more == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // global read less + static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_less == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // Stage 2, Sync + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_bufs, + const ABlockTransferStep& a_block_copy_step, + // B Gate and Up + const 
BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_bufs, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + // Gate and Up scale + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + BScaleThreadTransfer& b_scale_thread_copy_up, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleGridBuffer& b_scale_grid_buf_up, + index_t num_loop) const + { + ignore = b_block_bufs; + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf_up = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf_up = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs_up; + + // Global prefetch 1 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0)); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I0)); + b_blockwise_copy_up.Run( + b_grid_desc, b_grid_buf_up, b_block_desc, b_block_origin_idx, b_thread_bufs_up(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I0)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales_gate + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + 
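+        // The rewind below undoes the NWaves-sized column steps taken inside the
+        // n0 loop and advances the scale window to the next K group. Rough worked
+        // example (illustrative sizes only): with NRepeat = 8, NXdlPack = 2,
+        // KRepeat = 4, KXdlPack = 2 and NWaves = 2, the loop issues 4 * 2 = 8
+        // scale reads and drifts 4 * 2 = 8 columns, so this move is (-8, +2, 0).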
b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales_up + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs_up(I0)); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefetch 1, sync the async load + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + static_for<0, LocalPrefetchStages, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple( + I0, I0, Number{}, I0, Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple( + I0, I0, Number{}, k, Number{}), + a_thread_buf); + }); + }); + }); + + // Global prefetch 2 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1)); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + __builtin_amdgcn_sched_barrier(0); + constexpr index_t SwitchM = MRepeat - LocalPrefetchStages; + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc, + b_block_origin_idx, + b_thread_bufs(scale_mem_buf)); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc, + b_block_origin_idx, + b_thread_bufs_up(scale_mem_buf)); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(scale_mem_buf)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales_gate + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(scale_mem_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + 
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales_up + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs_up(scale_mem_buf)); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset( + make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset( + make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + vector_type + b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + // B Gate scale + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + // B Up scale + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up( + scale_comp_buf)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [scale_comp_buf][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up + [scale_comp_buf][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation A * Gate + xdlops_gemm.template Run( 
+ a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + // MFMA accumulation A * Up + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up + .template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value == SwitchM) + { + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + a_blockwise_copy.Run(a_grid_desc, + a_grid_buf, + a_block_desc, + a_block_bufs(scale_comp_buf)); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + } + + constexpr auto lds_buf = + m0.value >= SwitchM ? scale_mem_buf : scale_comp_buf; + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, + xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % + (MRepeat / MXdlPack)>{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(Number{}), + a_thread_desc_, + make_tuple(I0, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I1)); + b_blockwise_copy_up.Run( + b_grid_desc, b_grid_buf_up, b_block_desc, b_block_origin_idx, b_thread_bufs_up(I1)); + + // Prefetch a_scales_up + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I1)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales_gate + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales_up + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs_up(I1)); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto 
im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + // B Gate scale + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + // B Up scale + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation A * Gate + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + // MFMA accumulation A * Gate + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + if constexpr(m0.value == SwitchM) + { + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + } + + constexpr auto lds_buf = m0.value >= SwitchM ? 
I1 : I0; + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % + (MRepeat / MXdlPack)>{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(Number{}), + a_thread_desc_, + make_tuple( + I0, I0, Number{}, k, Number{}), + a_thread_buf); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I1)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation A * Gate + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + // MFMA accumulation A * Up + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) + { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops 
/ (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % + (MRepeat / MXdlPack)>{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I1), + a_thread_desc_, + make_tuple(I0, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + } + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + // B Gate scale + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + // B Up scale + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation A * Gate + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + // MFMA accumulation A * up + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up.template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) + { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * 
KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % + (MRepeat / MXdlPack)>{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple(I0, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + } + }); + } + } + + // Length: A[ARegBuf, MWave, MXdlPack, KRepeat, KPack] + // Order: 1 0 3 2 4 + static constexpr auto ARegBuf = 2; + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{}, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4>, + 4, + A_K1, + A_K1>; + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + // using Base::a_thread_copy_; + // using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp new file mode 100755 index 000000000000..6789d26a4592 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
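+// Compile-time selector for the MX (microscaling) MoE B-preshuffle GEMM
+// pipelines defined in the two headers included below. Dispatch is on
+// BlockGemmPipelineVersion and the GUFusion flag:
+//   - v3, GUFusion == true  -> BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3
+//   - v3, GUFusion == false -> BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3
+//   - v1                    -> not implemented yet; both branches return nullptr
+// Any other configuration falls through to the std::cerr diagnostic at the end
+// of the selector. A caller typically materializes the selected pipeline type
+// via decltype (sketch only, template arguments elided):
+//   using BlockwiseGemmPipe =
+//       remove_cvref_t<decltype(BlockGemmMXBPreshufflePipeline_Selector<...>())>;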
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp" + +namespace ck { +template +constexpr auto BlockGemmMXBPreshufflePipeline_Selector() +{ + + // Hardware MX GEMM pipeline + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if constexpr(GUFusion) + { + return nullptr; + } + else + { + return nullptr; + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v3< + BlkGemmPipeSche, + ThreadBlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3< + BlkGemmPipeSche, + ThreadBlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + } + else + { + std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp new file mode 100755 index 000000000000..2b936c8d25a1 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp @@ -0,0 +1,1041 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
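+// Implementation notes (summary of the pipeline below):
+//  * B is pre-shuffled and read from global memory straight into per-thread
+//    register buffers; the LDS-side b_block_bufs argument is ignored.
+//  * A is staged through LDS; its global loads are synchronized with
+//    __builtin_amdgcn_s_waitcnt using the manually packed immediate
+//    async_vmcnt_encoding, which places the outstanding-load count in the
+//    split vmcnt field (bits [3:0] and [15:14]) while leaving expcnt/lgkmcnt
+//    at their maxima (this assumes the gfx9-style s_waitcnt immediate layout
+//    used by the MFMA/XDLOPS targets).
+//  * E8M0 block scales for A and B are prefetched per K block and packed into
+//    KXdlPack*MXdlPack (A) and KXdlPack*NXdlPack (B) scale vectors before each
+//    scaled xdlops_gemm.Run call.
+//  * The hot loop is double buffered (LoopFunc(I0, I1) then LoopFunc(I1, I0));
+//    the last one (TailNumber::Odd) or two (TailNumber::Even) iterations are
+//    handled by the dedicated tail branches.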
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v3 + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::A_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::APackedSize; + using Base::BMmaKStride; + using Base::BPackedSize; + using Base::KThreadChunk; + + using Base::KXdlPack; + using Base::MXdlPack; + using Base::NXdlPack; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t LocalPrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 
0 : 1; + + static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack; + static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack; + static constexpr auto async_vmcnt = + num_buffer_load_a_scale + num_buffer_load_b_scale + HotLoopInstList::B_Buffer_Load_Inst_Num; + static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_stage1 = + num_buffer_load_inst_b + num_buffer_load_a_scale + num_buffer_load_b_scale; + + constexpr auto num_buffer_load_stage2 = num_buffer_load_inst_a; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 
8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 8, 2 * ds_read_a_issue_cycle); + + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + + constexpr auto num_total_stages = MRepeat; + + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto num_ds_read_a_prefetch_stages = 2; + + constexpr auto buffer_load_perstage_more = + math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = + math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_stage2 = + math::integer_divide_floor((num_buffer_load_stage2), 2); + + constexpr auto buffer_load_stages_more = + num_buffer_load_stage1 - + math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) * + ((num_total_stages - 2)); + + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less; + constexpr auto buffer_load_issue_point_interval_stage2 = + num_mfma_perstage / buffer_load_perstage_stage2; + + // Stage 1 + // global read more + static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + if constexpr(imfma % buffer_load_issue_point_interval_more == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // global read less + static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_less == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // Stage 2, Sync + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_bufs, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& 
b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_bufs, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + index_t num_loop) const + { + ignore = b_block_bufs; + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + + // Global prefetch 1 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0)); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I0)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefetch 1, sync the async load + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + static_for<0, LocalPrefetchStages, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple( + I0, I0, 
Number{}, I0, Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple( + I0, I0, Number{}, k, Number{}), + a_thread_buf); + }); + }); + }); + + // Global prefetch 2 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1)); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + __builtin_amdgcn_sched_barrier(0); + constexpr index_t SwitchM = MRepeat - LocalPrefetchStages; + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc, + b_block_origin_idx, + b_thread_bufs(scale_mem_buf)); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(scale_mem_buf)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(scale_mem_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset( + make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset( + make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; 
+ }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [scale_comp_buf][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value == SwitchM) + { + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + a_blockwise_copy.Run(a_grid_desc, + a_grid_buf, + a_block_desc, + a_block_bufs(scale_comp_buf)); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + } + + constexpr auto lds_buf = + m0.value >= SwitchM ? scale_mem_buf : scale_comp_buf; + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, + xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % + (MRepeat / MXdlPack)>{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(Number{}), + a_thread_desc_, + make_tuple(I0, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I1)); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I1)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, 
NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + if constexpr(m0.value == SwitchM) + { + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + } + + constexpr auto lds_buf = m0.value >= SwitchM ? 
I1 : I0; + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % + (MRepeat / MXdlPack)>{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(Number{}), + a_thread_desc_, + make_tuple( + I0, I0, Number{}, k, Number{}), + a_thread_buf); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) + { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % + (MRepeat / MXdlPack)>{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I1), + a_thread_desc_, + make_tuple(I0, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + } + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 
1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) + { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % + (MRepeat / MXdlPack)>{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple(I0, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + } + }); + } + } + + // Length: A[ARegBuf, MWave, MXdlPack, KRepeat, KPack] + // Order: 1 0 3 2 4 + static constexpr auto ARegBuf = 2; + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{}, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4>, + 4, + A_K1, + A_K1>; + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = 
make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + // using Base::a_thread_copy_; + // using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp new file mode 100755 index 000000000000..c6966011b4c9 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp @@ -0,0 +1,255 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp" + +namespace ck { + +template +constexpr auto BlockGemmBPreshufflePipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if constexpr(std::is_same::value) + { + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}; + } + } + else + { + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_v2{}; + } + else if constexpr(BlkGemmPipelineVer == 
BlockGemmPipelineVersion::v3) + { + static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3"); + if constexpr(std::is_same::value) + { + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + + return BlockwiseGemmXdlops_pipeline_bpreshuffle_v3{}; + } + } + else + { + return BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + } + else + { + std::cerr << "BlockGemmPipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp new file mode 100755 index 000000000000..d8f11572a879 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp @@ -0,0 +1,525 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + using Base::MWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr 
index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves; + constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2; + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + if constexpr(MPerBlock >= 128 && NPerBlock >= 128) + { + __builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0); + } + else + { + __builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0); + } + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}( + [&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 
2 : 1, 0); // DS read + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + }); + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + // loop prefetch copy + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + 
make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + // tail Local Prefetch A1 + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy 
a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp new file mode 100755 index 000000000000..601756be442f --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v2.hpp @@ -0,0 +1,572 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 3 +// LocalPreFillStages: 2 +// LocalPreFetchStages: 2 +// LocalSharedMemoryBuffer: 2 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v2 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 3; + static constexpr index_t PrefillStages = 2; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr 
TailNumber BlockLoopTailNum(index_t num_loop) + { + + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + // B global + A local + static_for<0, num_buffer_load_inst_b / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read B + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read A + }); + + static_for<0, num_buffer_load_inst_b / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read B + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read A + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + // static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) { + // ignore = i; + // __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read + // }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + // Global prefetch A1, B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0), I0); + + // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_bufs(I0)); + }); + }); + }); + + // Local prefill A2 + 
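    // Prologue recap (informal; inferred from the copies in this Run body and
    // the PrefetchStages = 3 / PrefillStages = 2 constants above): global A1+B1
    // are issued first, A1 is prefilled into LDS buffer 0 while global A2 is in
    // flight, A1 is then read from LDS into a_thread_bufs[0], and the two steps
    // below prefill A2 into LDS buffer 1 and issue global A3. The hot loop can
    // therefore ping-pong the two LDS buffers:
    //
    //     LoopFunc(I0, I1);   // MFMAs consume buffer 0, buffer 1 is refilled
    //     LoopFunc(I1, I0);   // MFMAs consume buffer 1, buffer 0 is refilled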
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1), I1); + + // // Global prefetch A3 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + block_sync_lds(); + + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // main loop A matrix prefetch + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_bufs(local_read_buf)); + }); + }); + }); + + a_blockwise_copy.RunWrite( + a_block_desc, a_block_buf.At(mfma_reg_buf), mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 3)); + } + // tail + + auto ReadWriteCompFunc = [&](auto mfma_reg, auto local_read_reg) { + block_sync_lds(); + + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_reg)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // tail prefetch A + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(local_read_reg), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_bufs(local_read_reg)); + }); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(mfma_reg), mfma_reg); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + 
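                // Accumulator indexing note: c_thread_desc_ is the packed
                // (MRepeat, NRepeat, RegSizePerXdlops) per-thread C buffer, so
                // CalculateOffset((m0, n0, 0)) is
                // (m0 * NRepeat + n0) * RegSizePerXdlops. Worked example with
                // assumed sizes (not fixed by this header): MRepeat = 4,
                // NRepeat = 2, RegSizePerXdlops = 16 gives
                // (m0 = 1, n0 = 1) -> (1 * 2 + 1) * 16 = 48 out of
                // 4 * 2 * 16 = 128 accumulator values held per thread.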
xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + auto ReadCompFunc = [&](auto mfma_reg, auto local_read_reg) { + block_sync_lds(); + + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_reg)); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(local_read_reg), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_bufs(local_read_reg)); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + auto CompFunc = [&](auto mfma_reg) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }; + + if constexpr(TailNum == TailNumber::Even) + { + ReadCompFunc(I0, I1); + CompFunc(I1); + } + else if constexpr(TailNum == TailNumber::Odd) + { + ReadWriteCompFunc(I0, I1); + ReadCompFunc(I1, I0); + CompFunc(I0); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp new file mode 100755 index 000000000000..6f0404a1caae --- /dev/null +++ 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v3.hpp @@ -0,0 +1,769 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +#define DS_READ_A_PREFETCH_STAGES 2 + +template +constexpr auto compute_stage_loads(T total_loads, T stages) +{ + return std::make_pair((total_loads + stages - 1) / stages, // ceil + total_loads / stages // floor + ); +} + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? 
HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + static_assert(num_buffer_load_inst_a == num_ds_write_inst_a); + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); + + constexpr auto num_total_stages = MRepeat; + + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / MRepeat; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / MRepeat; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto total_buffer_loads = num_buffer_load_inst_a + num_buffer_load_inst_b; + constexpr auto stages_available = MRepeat - DS_READ_A_PREFETCH_STAGES; + + constexpr auto stage_loads = compute_stage_loads(total_buffer_loads, stages_available); + + constexpr auto buffer_load_perstage_more = stage_loads.first; + constexpr auto buffer_load_perstage_less = stage_loads.second; + + constexpr auto buffer_load_stages_more = total_buffer_loads % stages_available; + + constexpr auto buffer_b_heavy_loads = buffer_load_perstage_more * buffer_load_stages_more; + constexpr auto buffer_b_remaining = + num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more; + + constexpr auto buffer_load_b_stages = + buffer_b_heavy_loads > num_buffer_load_inst_b + ? num_buffer_load_inst_b / buffer_load_perstage_more + : (buffer_load_stages_more + buffer_b_remaining / buffer_load_perstage_less); + + constexpr auto buffer_load_a_stages = + num_total_stages - DS_READ_A_PREFETCH_STAGES - buffer_load_b_stages; + + static_assert(buffer_load_a_stages > 0, + "The buffer load a stages should always have a value over 0."); + + constexpr auto buffer_load_issue_point_interval_more = + math::integer_divide_ceil(num_mfma_perstage, buffer_load_perstage_more); + constexpr auto buffer_load_issue_point_interval_less = + buffer_load_perstage_less == 0 + ? INT32_MAX + : math::integer_divide_ceil(num_mfma_perstage, buffer_load_perstage_less); + constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 
1 : 0; + + // B global read + static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); + + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == 0)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == 0))) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0); + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); + } + }); + }); + + // A global read + A local write + static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == 0)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == 0))) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_LDS_WRITE, 1, 0); + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_VMEM, 1, 0); + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); + } + }); + }); + + // lds synchronization, prefetch next loop local A + static_for<0, DS_READ_A_PREFETCH_STAGES, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(SCHED_GROUP_MFMA, 1, 0); + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + SCHED_GROUP_LDS_READ, ds_read_a_mfma_rate, 0); + } + }); + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + // Global prefetch A1 B1 + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + + // Global prefetch A2 + 
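        // Register-footprint note for this pipeline: a_thread_desc_ keeps only
        // two MRepeat stages of A in VGPRs (its leading extent is 2, see the
        // "Reduce the vgpr usage here" descriptor further below), and the hot
        // loop refills stage (m0 + 2) % MRepeat from LDS while stage m0 feeds
        // the MFMAs. DS_READ_A_PREFETCH_STAGES = 2 is the depth of that rolling
        // window, so the prologue only pre-reads m0 = 0 and m0 = 1. With an
        // assumed MRepeat = 8, just 2 of the 8 A stages are register-resident
        // at any time.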
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + static_for<0, DS_READ_A_PREFETCH_STAGES, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value == (MRepeat - 2)) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<0>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + HotLoopScheduler(); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + 
b_thread_bufs(I1)); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value == (MRepeat - 2)) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number<0>{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + }); + + HotLoopScheduler(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value < (MRepeat - 2)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + + HotLoopScheduler(); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + 
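                        // The static_for above gathers KPack scalars of A and B
                        // into register vectors so xdlops_gemm.Run can reinterpret
                        // them as its native MFMA operand type. Illustrative sizing
                        // (assumed, not fixed by this header): bf16 inputs with
                        // KPack = 8 make each operand vector 8 * 2 = 16 bytes, which
                        // xdlops_gemm.Run then feeds to however many hardware MFMA
                        // issues the chosen XDL shape requires.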
using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value < (MRepeat - 2)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + // Reduce the vgpr usage here. + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(I2, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp new file mode 100755 index 000000000000..48fe5131d2b8 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
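// Note on how this selector is meant to be consumed (illustrative sketch; the
// alias name and the trailing template arguments are placeholders, not part of
// this header): BlockGemmPipeline_Selector returns a default-constructed
// pipeline object whose type encodes the requested version, and callers keep
// only the decltype, e.g.
//
//   using BlockGemmPipe = remove_cvref_t<decltype(
//       BlockGemmPipeline_Selector<BlockGemmPipelineVersion::v3,
//                                  BlockGemmPipelineScheduler::Intrawave,
//                                  /* BlockSize, data types, tile descriptors,
//                                     per-vector widths, tile sizes, KPack */>())>;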
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp" + +namespace ck { + +template +constexpr auto BlockGemmPipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + return BlockwiseGemmXdlops_pipeline_v1_b_scale{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + return BlockwiseGemmXdlops_pipeline_v2_b_scale{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return BlockwiseGemmXdlops_pipeline_v3_b_scale{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return BlockwiseGemmXdlops_pipeline_v4_b_scale{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5) + { + return BlockwiseGemmXdlops_pipeline_v5{}; + } + else + { + std::cerr << "BlockGemmPipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp new file mode 100755 index 000000000000..9296b8136f06 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -0,0 +1,399 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template +struct BlockwiseGemmXdlops_pipeline_base +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs. + static constexpr index_t WaveSize = 64; + + static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); + static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = + BTileDesc{}.GetLength(Number < BTileDesc{}.GetNumOfDimension() == 4 ? 3 : 2 > {}); + + static constexpr auto xdlops_gemm = + XdlopsGemm{}; + + static constexpr index_t AMmaKStride = KPack; + static constexpr index_t BMmaKStride = KPack; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + static constexpr index_t KRepeat = KPerThread / KPack; + static constexpr index_t KPerInnerLoop = KPack; + + static constexpr index_t KGroup = []() { + if constexpr(is_same_v, f8_t>) + // On gfx950, we have mfma that required 32 f8 elements as input, + // splited into 2 groups of 16 f8 elements. + // the 2 groups is not contiguous in the B preshuffed layout. 
+ // and we do not want it to be contiguous in the B preshuffled layout + // because a memory instruction can only read 16 f8 elements at a time. + return ((MPerXDL == 16 && MPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) || + (MPerXDL == 32 && MPerXDL == 32 && xdlops_gemm.KPerXdlops == 64)) + ? 2 + : 1; + else + return 1; + }(); + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + using HotLoopInstList = + ck::BlockwiseGemmXdlops_pipeline_hotloop_inst; + + static_assert(KPerThread % KPack == 0, + "Wrong KPack setting; try increasing KPerThread or decreasing KPack"); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateAThreadOriginDataIndex6D() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], 0, xdlops_a_idx[I0], 0); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i); + + return make_tuple( + m0, n0, 
waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]); + } + + using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); + + /** + * @brief Constructor for BlockwiseGemmXdlops_pipeline_base. + * + * This constructor initializes the thread copy objects for matrices A and B. + * It also performs several compile-time checks to ensure the correctness of the + * matrix tile descriptors. + * + * @param a_origin The origin data index for matrix A. + * @param b_origin The origin data index for matrix B. + * + * @note The constructor includes static assertions to ensure that: + * - The matrix tile descriptors for A and B are known at compile-time. + * - The number of threads in the thread block matches the product of MWaves, NWaves, and + * WaveSize. + * - The dimensions of the block are divisible by the product of the corresponding XDL and + * repeat dimensions. + */ + __host__ __device__ + BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), + Tuple4 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, N, M0, M1, M2)); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2); + 
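        // Sanity check on the block-level decomposition used here: in the
        // m0_n0_m1_n1_m2_n2 naming the six extents are
        // (MRepeat, NRepeat, MWaves, NWaves, MPerXDL, NPerXDL), which multiply
        // back to MPerBlock * NPerBlock because
        // MWaves = MPerBlock / (MRepeat * MPerXDL) and
        // NWaves = NPerBlock / (NRepeat * NPerXDL). Assumed example: with
        // MPerBlock = NPerBlock = 128, MPerXDL = NPerXDL = 32 and
        // MRepeat = NRepeat = 2, the block descriptor is (2, 2, 2, 2, 32, 32)
        // and covers the full 128 x 128 C tile.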
} + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + __host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; } + static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; + static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; + + protected: + // M1, N1 as double buffer index + // Read buffer + Compute buffer + // A[M0, M1, M2, KPack] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // B[N0, N1, N2, KPack] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git 
a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_selector.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_selector.hpp new file mode 100755 index 000000000000..818439fddf74 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_selector.hpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp" +namespace ck { + +template +constexpr auto BlockGemmBlockScaleBPreshufflePipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + return BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } +#if 0 + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + return BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v2< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } +#endif + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3"); + return BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + std::cerr << "BlockGemmPipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp new file mode 100755 index 000000000000..8e2922e2ce94 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v1.hpp @@ -0,0 +1,864 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
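// Reference semantics of the block-scale pipeline below (host-side sketch for
// orientation only; ref_blockscale_gemm, A, B, a_scale and b_scale are
// illustrative names, not part of this header):
//
//   float ref_blockscale_gemm(int m, int n, int K, int KScaleBlock)
//   {
//       float acc = 0.f;
//       for(int kb = 0; kb < K / KScaleBlock; ++kb)
//       {
//           float partial = 0.f;
//           for(int k = kb * KScaleBlock; k < (kb + 1) * KScaleBlock; ++k)
//               partial += to_float(A[m][k]) * to_float(B[k][n]);
//           acc += a_scale[m][kb] * b_scale[n][kb] * partial; // per-block dequant
//       }
//       return acc;
//   }
//
// The kernel-side pipeline fuses the same math into its hot loop: it forms
// c_scale = a_scale * b_scale once per scale block and folds each per-scale
// MFMA partial into the accumulator with a packed FMA.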
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v1 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::MWaves; + using Base::NWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves; + + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + const CScaleThreadDesc& c_scale_thread_desc, + CThreadBuffer& c_thread_buf, + // AScaleThreadCopy + const AScaleGridDesc& a_scale_grid_desc, + const AScaleThreadDesc& a_scale_thread_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const AScaleThreadTransferStep& a_scale_thread_copy_step, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop) const + { + ignore = b_block_desc; + ignore = b_block_buf; + // __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + 
a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + __builtin_amdgcn_sched_barrier(0); + + constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{}); + constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{}); + constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); + + static_for<0, num_scale_m_block, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + StaticBufferTupleOfVector + c_thread_buf_per_scale; + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + + // __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 
1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = + CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number< + b_thread_desc_.CalculateOffset(make_tuple( + n0, + I0, + kscale0 * KRepeat / num_scale_k_block + k0, + ik))>{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}( + [&](auto t) { + using pk_fma_type = + typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = + __builtin_elementwise_fma( + c_thread_buf_per_scale + .GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec + .template AsType()[Number<0>{}], + c_thread_buf + .GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + 
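+ // Advance the B-scale read window for the next unrolled iteration. LoopFunc is issued
+ // twice per trip below (an I0/I1 ping-pong of the register buffers), so `i` steps by 2
+ // and the last one or two K iterations are handled by the Even/Odd tail paths further down.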
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + 
make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy 
a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp new file mode 100755 index 000000000000..cc4c5a2c36d9 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v3.hpp @@ -0,0 +1,1090 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t LocalPrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 
0 : 1; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves; + + static_assert(num_buffer_load_inst_a == num_ds_write_inst_a); + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 
8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); + + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + + constexpr auto num_total_stages = MRepeat; + + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto buffer_load_perstage_more = + math::integer_divide_ceil((num_buffer_load_inst_a + num_buffer_load_inst_b), + (num_total_stages - (LocalPrefetchStages - 1))); + constexpr auto buffer_load_perstage_less = + math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b), + (num_total_stages - (LocalPrefetchStages - 1))); + + constexpr auto buffer_load_stages_more = + (num_buffer_load_inst_a + num_buffer_load_inst_b) - + math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b), + (num_total_stages - (LocalPrefetchStages - 1))) * + ((num_total_stages - (LocalPrefetchStages - 1))); + + constexpr auto buffer_load_b_stages = + buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b + ? num_buffer_load_inst_b / buffer_load_perstage_more + : (buffer_load_stages_more + + (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) / + buffer_load_perstage_less); + + constexpr auto buffer_load_a_stages = + num_total_stages - (LocalPrefetchStages - 1) - buffer_load_b_stages; + + constexpr auto buffer_load_issue_point_b = 0; + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more + ? num_mfma_perstage / buffer_load_perstage_more + : 1; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less + ? num_mfma_perstage / buffer_load_perstage_less + : 1; + constexpr auto ds_write_issue_point = 0; + constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 
1 : 0; + + // B global read + static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_b)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_b))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr((imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage - 1)) && + (imfma < (num_mfma_perstage - 1))) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + // __builtin_amdgcn_sched_group_barrier(0x1000, 4, 0); // v_fmac + }); + // Scale load, 1B + if constexpr(i.value == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + // Scale load, 1A + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + // __builtin_amdgcn_sched_barrier(0); + }); + + // A global read + A local write + static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + ds_write_issue_point)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + ds_write_issue_point))) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr((imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage - 1)) && + (imfma < (num_mfma_perstage - 1))) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + // __builtin_amdgcn_sched_group_barrier(0x1000, 4, 0); // v_fmac + }); + // Scale load, 1A + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + // __builtin_amdgcn_sched_barrier(0); + }); + + // lds synchronization, prefetch next loop local A + static_for<0, (LocalPrefetchStages - 1), 1>{}([&](auto i) { + ignore = i; + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + if constexpr((imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage - 1)) && + (imfma < (num_mfma_perstage - 1))) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + // __builtin_amdgcn_sched_group_barrier(0x1000, 4, 0); // v_fmac + }); + // Scale load, 1A + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + // __builtin_amdgcn_sched_barrier(0); + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const 
ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + const CScaleThreadDesc& c_scale_thread_desc, + CThreadBuffer& c_thread_buf, + // AScaleThreadCopy + const AScaleGridDesc& a_scale_grid_desc, + const AScaleThreadDesc& a_scale_thread_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const AScaleThreadTransferStep& a_scale_thread_copy_step, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop) const + { + ignore = b_block_desc; + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + static_assert(CScaleThreadDesc{}.GetLength(Number<0>{}) == 1, + "Pipeline v3 only support scaleblocksliceK=1"); + static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1, + "Pipeline v3 only support scaleblocksliceN=1"); + // assume kperblock = scaleblockk + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + // StaticallyIndexedArray{}> c_scale_thread_bufs; + + // Global prefetch A1 B1, AScale1 BScale1 + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_bufs(I0)); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; + }); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + + // Global prefetch A2, AScale2 BScale2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + 
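+ // Second global prefetch stage (PrefetchStages == 2): the A tile for iteration 2 now sits
+ // in the blockwise copy's register staging while tile 1 has already been written to LDS
+ // above; the source window is advanced next so the first hot-loop iteration can overlap
+ // its global loads with the MFMA work on tile 1.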
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + +#if 1 + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_bufs(I0)); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); +#endif + // Initialize C + c_thread_buf.Clear(); + + // Double register buffer for non-scaled gemm computation + // 1. Reduce register pressure + // 2. Decouple the dependency between mfma instruction and scale-fma instruction following. + StaticBufferTupleOfVector + c_thread_buf_per_scale; + + // Local prefetch A1 + block_sync_lds(); + static_for<0, LocalPrefetchStages, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + +#if 1 + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + + // Fill first mfma buffer + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); +#endif + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs(local_read_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_bufs(local_read_buf)); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, 
a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto mfma_buf_offset = + ((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + constexpr auto scale_buf_offset = + ((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + + constexpr auto a_local_buf_offset = + ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat; + constexpr auto b_local_buf_offset = + ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat; + constexpr auto b_local_buf_id = + Number{}; + + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale + .GetVectorTypeReference(Number{}) + .template AsType()(Number{}) = 0; + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs + [b_local_buf_id][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference( + Number{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale + .GetVectorTypeReference(Number{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + // We have to 1 stage early sync the lds for workaround the compiler + // limitation + if constexpr(m0.value == (MRepeat - LocalPrefetchStages - 1)) + { + block_sync_lds(); + } + + constexpr auto lds_buf = m0.value >= (MRepeat - LocalPrefetchStages) + ? 
local_read_buf + : mfma_reg_buf; + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(Number{}), + a_thread_desc_, + make_tuple(Number<(m0 + LocalPrefetchStages + + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * + b_scale_thread_bufs[mfma_reg_buf][I0]; + }); + + // We need new compiler to enable this feature + // HotLoopScheduler(); + // __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto mfma_buf_offset = + ((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + constexpr auto scale_buf_offset = + ((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + + constexpr auto a_local_buf_offset = + ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat; + constexpr auto b_local_buf_offset = + ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat; + + constexpr auto b_local_buf_id = + Number<0 ^ ((m0 * NRepeat + n0 + 1) / (MRepeat * NRepeat))>{}; + + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[b_local_buf_id][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference( + Number{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale + .GetVectorTypeReference(Number{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value == (MRepeat - LocalPrefetchStages)) + { + block_sync_lds(); + } + + constexpr auto lds_buf = m0.value >= (MRepeat - LocalPrefetchStages) ? 
I1 : I0; + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + LocalPrefetchStages) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(Number{}), + a_thread_desc_, + make_tuple(Number<(m0 + LocalPrefetchStages) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + }); + + // HotLoopScheduler(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto mfma_buf_offset = + ((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + constexpr auto scale_buf_offset = + ((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + + constexpr auto a_local_buf_offset = + ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat; + constexpr auto b_local_buf_offset = + ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat; + + if constexpr(!((m0 == (MRepeat - 1)) && (n0 == (NRepeat - 1)))) + { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference( + Number{})); + }); + } + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale + .GetVectorTypeReference(Number{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + LocalPrefetchStages + HotloopLocalBufSwitch) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto mfma_buf_offset = + ((m0 * NRepeat + 
n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + constexpr auto scale_buf_offset = + ((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops(); + + constexpr auto a_local_buf_offset = + ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat; + constexpr auto b_local_buf_offset = + ((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat; + + if constexpr(!((m0 == (MRepeat - 1)) && (n0 == (NRepeat - 1)))) + { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference( + Number{})); + }); + } + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale + .GetVectorTypeReference(Number{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(Number<(m0 + LocalPrefetchStages) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + // Reduce the vgpr usage here. + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(I2, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp new file mode 100755 index 000000000000..1608506b40f4 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp @@ -0,0 +1,1036 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
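+ //
+ // Blockwise GEMM pipeline, MoE block-scale variant with preshuffled B and "gufusion".
+ // Judging from the duplicated B-side operands below (b_blockwise_copy_up, b_scale_*_up,
+ // c_thread_buf_up), this pipeline presumably fuses the gate and up projections of an MoE
+ // FFN so both GEMMs share a single pass over the A tile; the interface otherwise mirrors
+ // the non-fused block-scale pipelines above.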
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1< + BlockGemmPipelineScheduler::Intrawave, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack> : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::MWaves; + using Base::NWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = + HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves * 2; + constexpr auto mfma_interleave = MPerXDL == 32 ? 1 : 2; + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + if constexpr(MPerBlock >= 128 && NPerBlock >= 64) + { + __builtin_amdgcn_sched_group_barrier(0x008, 2 * mfma_interleave, 0); + } + else + { + __builtin_amdgcn_sched_group_barrier(0x008, mfma_interleave, 0); + } + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, MPerXDL == 32 ? num_ds_read_inst_a / 2 : num_ds_read_inst_a, 1>{}( + [&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, MPerXDL == 32 ? 2 : 1, 0); // DS read + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + const CScaleThreadDesc& c_scale_thread_desc, + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + // AScaleThreadCopy + const AScaleGridDesc& a_scale_grid_desc, + const AScaleThreadDesc& a_scale_thread_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const AScaleThreadTransferStep& a_scale_thread_copy_step, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + BScaleThreadTransfer& b_scale_thread_copy_up, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleGridBuffer& b_scale_grid_buf_up, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop) const + { + ignore = b_block_desc; + ignore = b_block_buf; + // __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf_up = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf = make_static_buffer( + 
c_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf_up = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf_up); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + // __builtin_amdgcn_sched_barrier(0); + + constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{}); + constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{}); + constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); + static_for<0, num_scale_m_block, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + c_scale_thread_buf_up(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf_up[Number{}]; + }); + }); + }); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf_up); + + 
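+ // Both B-scale streams (the base one and the `_up` copy) are now staged one iteration
+ // ahead, mirroring the two-stage data prefetch; the `_up` scale window is advanced next,
+ // after which the per-scale accumulators are set up and the hot loop begins.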
b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + StaticBufferTupleOfVector + c_thread_buf_per_scale; + StaticBufferTupleOfVector + c_thread_buf_per_scale_up; + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + + // __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + vector_type c_scale_thread_vec_up; + constexpr index_t cscale_offset = + CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number< + b_thread_desc_.CalculateOffset(make_tuple( + n0, + I0, + kscale0 * KRepeat / num_scale_k_block + k0, + ik))>{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[mfma_reg_buf][Number< + b_thread_desc_.CalculateOffset(make_tuple( + n0, + I0, + kscale0 * KRepeat / num_scale_k_block + k0, + ik))>{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference( + Number<0>{})); + }); + + 
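+ // The raw MFMA results for this (m0, n0) sub-tile are held in the unscaled per-scale
+ // accumulators; the packed FMA below multiplies them by the combined block scale
+ // (a_scale * b_scale) and adds them into the running C accumulators for both the base
+ // and the `_up` output.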
constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}( + [&](auto t) { + using pk_fma_type = + typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = + __builtin_elementwise_fma( + c_thread_buf_per_scale + .GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec + .template AsType()[Number<0>{}], + c_thread_buf + .GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = + __builtin_elementwise_fma( + c_thread_buf_per_scale_up + .GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up + .template AsType()[Number<0>{}], + c_thread_buf_up + .GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + c_scale_thread_buf_up(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf_up[Number{}]; + }); + }); + }); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf_up); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 
1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + vector_type c_scale_thread_vec_up; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + c_scale_thread_buf_up(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf_up[Number{}]; + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + // __builtin_amdgcn_sched_barrier(0); + + 
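+            // NOTE (editorial comment): with the two-stage global prefetch, an even loop
+            // count leaves two B tiles in flight when the hot loop exits. The pass above
+            // consumed b_thread_bufs[I0]; the pass below drains b_thread_bufs[I1], which
+            // was loaded at the top of this tail, so no further global reads are issued.
+            // An odd loop count instead takes the TailNumber::Odd branch and drains only I0.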
static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + vector_type c_scale_thread_vec_up; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + vector_type c_scale_thread_vec_up; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + 
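+                    // NOTE (editorial comment): the same scalar block-scale is written into
+                    // both halves of the 2-element vector so the packed FMA further below can
+                    // apply it to a pair of accumulator registers per instruction. The *_up
+                    // copies that follow do the same for the second, fused GEMM output
+                    // (presumably the up-projection path of the gated MLP that the
+                    // "gufusion" naming refers to).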
c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[Number{}]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp new file mode 100755 index 000000000000..30d6d4f81209 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp @@ -0,0 +1,1203 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
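+
+// NOTE (editorial sketch, illustration only; the plain-array names below are
+// hypothetical): the pipelines in this file compute a block-scaled GEMM in
+// which the inner K loop accumulates unscaled MFMA partials per K scale block
+// and the per-block scales are applied once per block via a packed FMA:
+//
+//     C[m][n] = sum over kb of  a_scale[m][kb] * b_scale[n][kb]
+//                               * ( sum over k in block kb of A[m][k] * B[n][k] )
+//
+// A rough host-side reference of the same math (row-major A; B indexed as
+// N x K to mirror the preshuffled/column-major layout used here):
+//
+//     for (int m = 0; m < M; ++m)
+//         for (int n = 0; n < N; ++n) {
+//             float c = 0.f;
+//             for (int kb = 0; kb < K / KScaleBlock; ++kb) {
+//                 float partial = 0.f;
+//                 for (int k = kb * KScaleBlock; k < (kb + 1) * KScaleBlock; ++k)
+//                     partial += float(A[m][k]) * float(B[n][k]);
+//                 c = fmaf(partial, a_scale[m][kb] * b_scale[n][kb], c);
+//             }
+//             C[m][n] = c;
+//         }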
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< + BlockGemmPipelineScheduler::Intrawave, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack> : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2; + + static_assert(num_buffer_load_inst_a == num_ds_write_inst_a); + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * 2; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); + + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + + constexpr auto num_total_stages = MRepeat; + + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto num_ds_read_a_prefetch_stages = 2; + + constexpr auto buffer_load_perstage_more = math::integer_divide_ceil( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = math::integer_divide_floor( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + + constexpr auto buffer_load_stages_more = + (num_buffer_load_inst_a + num_buffer_load_inst_b) - + math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b), + (num_total_stages - 2)) * + ((num_total_stages - 2)); + + constexpr auto buffer_load_b_stages = + buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b + ? num_buffer_load_inst_b / buffer_load_perstage_more + : (buffer_load_stages_more + + (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) / + buffer_load_perstage_less); + + constexpr auto buffer_load_a_stages = + num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages; + + constexpr auto buffer_load_issue_point_b = 0; + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more + ? num_mfma_perstage / buffer_load_perstage_more + : 1; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less + ? num_mfma_perstage / buffer_load_perstage_less + : 1; + constexpr auto ds_write_issue_point = 0; + constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 
1 : 0; + + // B global read + static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { + // Scale load, 1B + if constexpr(i.value == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + // Scale load, 1A + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_b)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_b))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + // __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + }); + // __builtin_amdgcn_sched_barrier(0); + }); + + // A global read + A local write + static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { + // Scale load, 1A + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + ds_write_issue_point)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + ds_write_issue_point))) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + // __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + }); + // __builtin_amdgcn_sched_barrier(0); + }); + + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // Scale load, 1A + if constexpr(imfma == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + // __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + }); + // __builtin_amdgcn_sched_barrier(0); + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + const 
BBlockTransferStep& b_block_copy_step, + // CThread + const CScaleThreadDesc& c_scale_thread_desc, + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + // AScaleThreadCopy + const AScaleGridDesc& a_scale_grid_desc, + const AScaleThreadDesc& a_scale_thread_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const AScaleThreadTransferStep& a_scale_thread_copy_step, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + BScaleThreadTransfer& b_scale_thread_copy_up, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleGridBuffer& b_scale_grid_buf_up, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop) const + { + ignore = b_block_desc; + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + static_assert(CScaleThreadDesc{}.GetLength(Number<0>{}) == 1, + "Pipeline v3 only support scaleblocksliceK=1"); + static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1, + "Pipeline v3 only support scaleblocksliceN=1"); + // assume kperblock = scaleblockk + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs_up; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf_up = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs_up; + // StaticallyIndexedArray{}> c_scale_thread_bufs; + + // Global prefetch A1 B1, AScale1 BScale1 + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I0)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_bufs(I0)); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs_up(I0)); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, 
b_scale_thread_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; + }); + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf_up(m0) = + a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs_up[I0][I0]; + }); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + + // Global prefetch A2, AScale2 BScale2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_bufs(I0)); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs_up(I0)); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + + // Double register buffer for non-scaled gemm computation + // 1. Reduce register pressure + // 2. Decouple the dependency between mfma instruction and scale-fma instruction following. + StaticBufferTupleOfVector + c_thread_buf_per_scale; + StaticBufferTupleOfVector + c_thread_buf_per_scale_up; + + // Local prefetch A1 + block_sync_lds(); + static_for<0, 2, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(local_read_buf)); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_bufs(local_read_buf)); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + } + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + 
b_scale_thread_bufs(local_read_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs_up(local_read_buf)); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + vector_type c_scale_thread_vec_up; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[m0]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up + .template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value == (MRepeat - 2)) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + 
make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * + b_scale_thread_bufs[mfma_reg_buf][I0]; + c_scale_thread_buf_up(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * + b_scale_thread_bufs_up[mfma_reg_buf][I0]; + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + b_blockwise_copy_up.Run(b_grid_desc, + b_grid_buf_up, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs_up(I1)); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + vector_type c_scale_thread_vec_up; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[m0]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + 
c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value == (MRepeat - 2)) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + }); + + HotLoopScheduler(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; + c_scale_thread_buf_up(m0) = + a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs_up[I0][I0]; + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + vector_type c_scale_thread_vec_up; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[m0]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + 
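+                    // NOTE (editorial comment): the accumulator registers are walked two at a
+                    // time below (hence GetRegSizePerXdlops() / 2) and updated through
+                    // __builtin_elementwise_fma on a 2-wide float vector, which the scheduler
+                    // comments elsewhere in this file expect to lower to packed v_pk_fma
+                    // instructions.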
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value < (MRepeat - 2)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // // __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + vector_type c_scale_thread_vec_up; + c_scale_thread_vec_up.template AsType()(Number<0>{}) = + c_scale_thread_buf_up[m0]; + c_scale_thread_vec_up.template AsType()(Number<1>{}) = + c_scale_thread_buf_up[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_bufs_up[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec_up.template AsType(), + c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + 
c_thread_buf_per_scale_up.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec_up.template AsType()[Number<0>{}], + c_thread_buf_up.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value < (MRepeat - 2)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + // Reduce the vgpr usage here. + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(I2, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_selector.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_selector.hpp new file mode 100755 index 000000000000..e7c061bd974c --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_selector.hpp @@ -0,0 +1,186 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
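+
+// NOTE (editorial usage sketch, not a definitive API description; the template
+// argument list is elided because its exact order is defined below): the
+// selector in this header is meant to be instantiated at compile time by the
+// gridwise kernel, which takes the type of the returned object as its
+// blockwise pipeline, along the lines of
+//
+//     using BlockwiseGemmPipe = remove_cvref_t<decltype(
+//         BlockGemmBlockMoeScaleBPreshufflePipeline_Selector<
+//             /* pipeline version, scheduler, block size, data types,
+//                tile descriptors, per-block and per-scale-block sizes,
+//                XDL shape, repeats, KPack, GUFusion flag */>())>;
+//
+// Only v1 and v3 are selectable: the v2 branch is stubbed out under "#if 0",
+// and an unsupported configuration falls through to the std::cerr message at
+// the end of the function.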
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp" +namespace ck { + +template +constexpr auto BlockGemmBlockMoeScaleBPreshufflePipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + } +#if 0 + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v2< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } +#endif + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3"); + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< + BlkGemmPipeSche, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + } + else + { + std::cerr << "BlockGemmPipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp new file mode 100755 index 
000000000000..598b69cd6188 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp @@ -0,0 +1,854 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< + BlockGemmPipelineScheduler::Intrawave, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack> : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::MWaves; + using Base::NWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num; + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves; + + // B global + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A global + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + }); + + // A local + static_for<0, num_ds_read_inst_a / 2, 1>{}([&](auto i) { + ignore = i; + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x100, 2, 0); // DS read + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + const CScaleThreadDesc& c_scale_thread_desc, + CThreadBuffer& c_thread_buf, + // AScaleThreadCopy + const AScaleGridDesc& a_scale_grid_desc, + const AScaleThreadDesc& a_scale_thread_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const AScaleThreadTransferStep& a_scale_thread_copy_step, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop) const + { + ignore = b_block_desc; + ignore = b_block_buf; + // __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + 
a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + __builtin_amdgcn_sched_barrier(0); + + constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{}); + constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{}); + constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); + + static_for<0, num_scale_m_block, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + + // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + StaticBufferTupleOfVector + c_thread_buf_per_scale; + + // Local prefetch A1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + + // __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type 
c_scale_thread_vec; + constexpr index_t cscale_offset = + CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number< + b_thread_desc_.CalculateOffset(make_tuple( + n0, + I0, + kscale0 * KRepeat / num_scale_k_block + k0, + ik))>{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}( + [&](auto t) { + using pk_fma_type = + typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = + __builtin_elementwise_fma( + c_thread_buf_per_scale + .GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec + .template AsType()[Number<0>{}], + c_thread_buf + .GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_buf); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + static_for<0, MRepeat, 
1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + + block_sync_lds(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + + // __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template 
AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + vector_type c_scale_thread_vec; + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[Number{}]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[Number{}]; + + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + }); + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp new file mode 100755 index 000000000000..6db02d1dd751 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp @@ -0,0 +1,1070 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3< + BlockGemmPipelineScheduler::Intrawave, + BlockSize, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MScaleBlock, + NScaleBlock, + KScaleBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack> : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::KGroup; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MWaves; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 
0 : 1; + + template + __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&) + { + constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); + constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); + constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); + constexpr index_t K2 = KPack / KGroup; + constexpr index_t K1 = 64 / NPerXDL; + constexpr index_t K0 = KRepeat * KGroup; + + return transform_tensor_descriptor( + TileDesc_M0_M1_M2_K{}, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 = + MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k); + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + static_assert(num_buffer_load_inst_a == num_ds_write_inst_a); + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 
8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle); + + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + + constexpr auto num_total_stages = MRepeat; + + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto num_ds_read_a_prefetch_stages = 2; + + constexpr auto buffer_load_perstage_more = math::integer_divide_ceil( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = math::integer_divide_floor( + (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2)); + + constexpr auto buffer_load_stages_more = + (num_buffer_load_inst_a + num_buffer_load_inst_b) - + math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b), + (num_total_stages - 2)) * + ((num_total_stages - 2)); + + constexpr auto buffer_load_b_stages = + buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b + ? num_buffer_load_inst_b / buffer_load_perstage_more + : (buffer_load_stages_more + + (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) / + buffer_load_perstage_less); + + constexpr auto buffer_load_a_stages = + num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages; + + constexpr auto buffer_load_issue_point_b = 0; + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more + ? num_mfma_perstage / buffer_load_perstage_more + : 1; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less + ? num_mfma_perstage / buffer_load_perstage_less + : 1; + constexpr auto ds_write_issue_point = 0; + constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 
1 : 0; + + // B global read + static_for<0, buffer_load_b_stages, 1>{}([&](auto i) { + // Scale load, 1B + if constexpr(i.value == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + // Scale load, 1A + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + if constexpr(((i < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_b)) || + ((i >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_b))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + // __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + }); + // __builtin_amdgcn_sched_barrier(0); + }); + + // A global read + A local write + static_for<0, buffer_load_a_stages, 1>{}([&](auto i) { + // Scale load, 1A + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + ds_write_issue_point)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + ds_write_issue_point))) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_more == + buffer_load_issue_point_a)) || + (((i + buffer_load_b_stages) >= buffer_load_stages_more) && + (imfma % buffer_load_issue_point_interval_less == + buffer_load_issue_point_a))) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + // __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + }); + // __builtin_amdgcn_sched_barrier(0); + }); + + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + // Scale load, 1A + if constexpr(imfma == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + // __builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma + }); + // __builtin_amdgcn_sched_barrier(0); + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + const 
CScaleThreadDesc& c_scale_thread_desc, + CThreadBuffer& c_thread_buf, + // AScaleThreadCopy + const AScaleGridDesc& a_scale_grid_desc, + const AScaleThreadDesc& a_scale_thread_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const AScaleThreadTransferStep& a_scale_thread_copy_step, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop) const + { + ignore = b_block_desc; + ignore = b_block_buf; + __builtin_amdgcn_sched_barrier(0); + static_assert(CScaleThreadDesc{}.GetLength(Number<0>{}) == 1, + "Pipeline v3 only support scaleblocksliceK=1"); + static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1, + "Pipeline v3 only support scaleblocksliceN=1"); + // assume kperblock = scaleblockk + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + // StaticallyIndexedArray{}> c_scale_thread_bufs; + + // Global prefetch A1 B1, AScale1 BScale1 + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + __builtin_amdgcn_sched_barrier(0); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_bufs(I0)); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; + }); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + + // Global prefetch A2, AScale2 BScale2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_bufs(I0)); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + } + + 
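+        // BScale2 prefetch: as with AScale2 above, the next iteration's B scale is read here so
+        // that the scale loads stay one pipeline stage ahead of the MFMA work in the hot loop.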
b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Double register buffer for non-scaled gemm computation + // 1. Reduce register pressure + // 2. Decouple the dependency between mfma instruction and scale-fma instruction following. + StaticBufferTupleOfVector + c_thread_buf_per_scale; + + // Local prefetch A1 + block_sync_lds(); + static_for<0, 2, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + }); + +#if 0 + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + + // Fill first mfma buffer + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); +#endif + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf)); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(I0, I0), + a_scale_thread_bufs(local_read_buf)); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + } + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_bufs(local_read_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + 
a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value == (MRepeat - 2)) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(local_read_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(mfma_reg_buf), + a_thread_desc_, + make_tuple( + Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) % + 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * + b_scale_thread_bufs[mfma_reg_buf][I0]; + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_n0_n1_k0_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + 
b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value == (MRepeat - 2)) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else if constexpr(m0.value == (MRepeat - 1)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + else + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(Number<(m0 + 2) % MRepeat>{}, + I0, + I0, + Number{}, + I0, + I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + }); + + HotLoopScheduler(); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0]; + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + 
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value < (MRepeat - 2)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I1), + a_thread_desc_, + make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{}, + I0, + I0, + k0, + I0, + Number{}), + a_thread_buf); + }); + }); + } + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle + // latency + // // __builtin_amdgcn_sched_barrier(0); + } + else + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + vector_type c_scale_thread_vec; + c_scale_thread_vec.template AsType()(Number<0>{}) = + c_scale_thread_buf[m0]; + c_scale_thread_vec.template AsType()(Number<1>{}) = + c_scale_thread_buf[m0]; + + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) { + using pk_fma_type = typename vector_type::type; + + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()(t) = __builtin_elementwise_fma( + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[t], + c_scale_thread_vec.template AsType()[Number<0>{}], + c_thread_buf.GetVectorTypeReference(Number{}) + .template AsType()[t]); + }); + }); + + if constexpr(m0.value < (MRepeat - 2)) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, KGroup, 1>{}([&](auto kg0) { + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple( + Number{}, I0, I0, Number{}, I0, I0), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple( + Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number{}), + a_thread_buf); + }); + }); + } + }); + } + } + + protected: + // MRepeat MWave MLane KRepeat KLane KPack + // KRepeat -> MRepeat-> Mwave->KLane->MLane->KPack + // Reduce the vgpr usage here. 
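+    // Only two MRepeat slices of A are kept in registers and ping-ponged between
+    // (first dimension is 2 instead of MRepeat), which is what the VGPR reduction above refers to.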
+ static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(I2, I1, I1, Number{}, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()}; + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + + static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; + + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_bpreshuffle_selector.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_bpreshuffle_selector.hpp new file mode 100755 index 000000000000..7d21c445048d --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_bpreshuffle_selector.hpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp" + +namespace ck { +template +constexpr auto BlockGemmMXBPreshufflePipeline_Selector() +{ + + // Hardware MX GEMM pipeline + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle{}; + } + else + { + std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp new file mode 100755 index 000000000000..66d221691bcf --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp @@ -0,0 +1,1332 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
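+// Blockwise GEMM pipeline for MX block-scaled (e8m0 exponent scale) MoE layers with the gate and
+// up projections fused: both GEMMs share the A tile and A scales and accumulate into separate C buffers.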
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3 + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::APackedSize; + using Base::BMmaKStride; + using Base::BPackedSize; + using Base::KThreadChunk; + + using Base::KXdlPack; + using Base::MXdlPack; + using Base::NXdlPack; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2 * 2; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2; + + constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack; + constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack * 2; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize * 2; + + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale + num_buffer_load_b_scale; + + constexpr auto mfma_perstage_more = + math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total); + constexpr auto mfma_perstage_less = + math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total); + + constexpr auto mfma_stages_more = + num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + if constexpr(i < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < mfma_stages_more) + { + static_for<0, 
mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + }); + } + + template + __device__ void Run( + // A + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_bufs, + const ABlockTransferStep& a_block_copy_step, + // Gate and Up + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_bufs, + BBlockBuffer& b_block_bufs_up, + const BBlockTransferStep& b_block_copy_step, + // C + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + // A scale + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + // Gate and Up scale + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + BScaleThreadTransfer& b_scale_thread_copy_up, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleGridBuffer& b_scale_grid_buf_up, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf_up = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf_up = make_static_buffer( + 
b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs_up; + + // Global prefetch 1 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0)); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I0)); + b_blockwise_copy_up.Run(b_grid_desc, b_grid_buf_up, b_block_desc, b_block_bufs_up(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I0)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales_gate + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales_up + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs_up(I0)); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefetch 1, sync the async load + __builtin_amdgcn_s_waitcnt(3952); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I0), + a_thread_desc_, + 
make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I0), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs_up(I0), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf_up); + }); + }); + }); + + // Global prefetch 2 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1)); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I1)); + b_blockwise_copy_up.Run(b_grid_desc, b_grid_buf_up, b_block_desc, b_block_bufs_up(I1)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) { + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf)); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(scale_comp_buf)); + b_blockwise_copy_up.Run( + b_grid_desc, b_grid_buf_up, b_block_desc, b_block_bufs_up(scale_comp_buf)); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(scale_mem_buf)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(scale_mem_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); 
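+                    // The same window walk is repeated below for the up-projection scales via
+                    // b_scale_thread_copy_up, before each scale window is rewound in N and advanced in K.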
+ + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales_up + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs_up(scale_mem_buf)); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + vector_type + b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up( + scale_comp_buf)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()( + ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()( + ik) = b_thread_buf + [Number{}]; + b_thread_vec_up.template AsType()( + ik) = b_thread_buf_up + [Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference( + Number{})); + + 
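+                                        // Second scaled-MFMA issue (sketch note): the call below
+                                        // reuses the same A fragment and A scales but feeds the
+                                        // fused "up" B operand (b_thread_vec_up with its own
+                                        // b_scale_thread_vec_up), accumulating into the separate
+                                        // c_thread_buf_up at the same c_offset (presumably the
+                                        // gate/up pair of the MoE FFN, per the "gufusion" naming).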
xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec_up + .template AsType(), + b_scale_thread_vec_up + .template AsType(), + c_thread_buf_up.GetVectorTypeReference( + Number{})); + }); + }); + }); + }); + }); + }); + + // k indexes mapping to threads for 32x32x64: + // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc. + // t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc. + // k = 0 k = 1 + + // k indexes mapping to threads for 16x16x128: + // t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc. + // t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc. + // t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc. + // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc. + // k = 0 k = 1 + // block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, + xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(scale_mem_buf), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, + xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(scale_mem_buf), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, + xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs_up(scale_mem_buf), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf_up); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I1)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + 
b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales_up + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs_up(I1)); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up + .template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, 
xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I1), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I1), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs_up(I1), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf_up); + }); + }); + }); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I1)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation 
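+                                // (Illustrative note) each xdlops_gemm.Run below consumes a
+                                // KPack-wide A/B thread vector together with its packed e8m0
+                                // block scales and accumulates into the c_thread_buf element
+                                // group selected by c_offset, i.e. (m0, n0, imxdl, inxdl).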
+ xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up + .template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up + .template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + using 
Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp new file mode 100755 index 000000000000..f2508d9cfabf --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp @@ -0,0 +1,1361 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3 + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::APackedSize; + using Base::BMmaKStride; + using Base::BPackedSize; + using Base::KThreadChunk; + + using Base::KXdlPack; + using Base::MXdlPack; + using Base::NXdlPack; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + 
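+    // Worked example for the pack-size check above (illustrative numbers, not a
+    // specific config): with e8m0 scales packed four to an int32-sized
+    // AScaleDataType (scale_pack_size_a == 4) and KXdlPack == MXdlPack == 2, a
+    // thread owns KXdlPack * MXdlPack == 4 scales per xdl-pack tile, i.e. exactly
+    // one packed word (a_scale_thread_vec_size == 1).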
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2 * 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num * 2; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2; + + constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack; + constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack * 2; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize * 2; + + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 
8 : 4; + + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale + num_buffer_load_b_scale; + + constexpr auto mfma_perstage_more = + math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total); + constexpr auto mfma_perstage_less = + math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total); + + constexpr auto mfma_stages_more = + num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total; + + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + if constexpr(i < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_a) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_a) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_a) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_b) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 
1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + }); + } + + template + __device__ void Run( + // A + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // B0/B1 + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + BBlockTransfer& b_blockwise_copy_up, + const BGridBuffer& b_grid_buf, + const BGridBuffer& b_grid_buf_up, + BBlockBuffer& b_block_buf, + BBlockBuffer& b_block_buf_up, + const BBlockTransferStep& b_block_copy_step, + // C + CThreadBuffer& c_thread_buf, + CThreadBuffer& c_thread_buf_up, + // A scale + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + // B0/B1 scale + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + BScaleThreadTransfer& b_scale_thread_copy_up, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleGridBuffer& b_scale_grid_buf_up, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf_up = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf_up = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs_up; + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + b_blockwise_copy_up.RunRead(b_grid_desc, b_grid_buf_up); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 
1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I0)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs_up(I0)); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + b_blockwise_copy_up.RunWrite(b_block_desc, b_block_buf_up); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + b_blockwise_copy_up.RunRead(b_grid_desc, b_grid_buf_up); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / 
(BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf_up, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf_up); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + c_thread_buf_up.Clear(); + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + b_blockwise_copy_up.RunWrite(b_block_desc, b_block_buf_up); + b_blockwise_copy_up.RunRead(b_grid_desc, b_grid_buf_up); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(scale_mem_buf)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(scale_mem_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales_up + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs_up(scale_mem_buf)); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, 
make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + vector_type + b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up( + scale_comp_buf)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()( + ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()( + ik) = b_thread_buf + [Number{}]; + b_thread_vec_up.template AsType()( + ik) = b_thread_buf_up + [Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference( + Number{})); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec_up + .template AsType(), + b_scale_thread_vec_up + .template AsType(), + c_thread_buf_up.GetVectorTypeReference( + Number{})); + }); + }); + }); + }); + }); + }); + + // k indexes mapping to threads for 32x32x64: + // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc. + // t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc. + // k = 0 k = 1 + + // k indexes mapping to threads for 16x16x128: + // t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc. + // t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc. + // t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc. 
+ // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc. + // k = 0 k = 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, + xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, + xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, + xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf_up, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf_up); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I1)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales_up + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy_up.Run(b_scale_grid_desc, + b_scale_grid_buf_up, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs_up(I1)); + + b_scale_thread_copy_up.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy_up.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, 
-KRepeat / KXdlPack, 0)); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + b_blockwise_copy_up.RunWrite(b_block_desc, b_block_buf_up); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up + .template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, 
xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf_up, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf_up); + }); + }); + }); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I1)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up + .template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) 
+ { + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + vector_type b_scale_thread_vec_up; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec_up.template AsType()(s) = + b_scale_thread_bufs_up(I0)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type b_thread_vec_up; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + b_thread_vec_up.template AsType()(ik) = + b_thread_buf_up[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec_up.template AsType(), + b_scale_thread_vec_up + .template AsType(), + c_thread_buf_up.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp new file mode 100755 index 000000000000..84b0eebb3149 --- /dev/null +++ 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_gufusion_v3.hpp" + +namespace ck { +template +constexpr auto BlockGemmMXNBSPipeline_Selector() +{ + + // Hardware MX GEMM pipeline + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if constexpr(GUFusion) + { + return nullptr; + } + else + { + return BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1{}; + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3< + BlkGemmPipeSche, + ThreadBlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}; + } + } + else + { + std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp new file mode 100755 index 000000000000..32f624854387 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v1.hpp @@ -0,0 +1,664 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
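+// Note (sketch): this header provides the single-LDS-buffer pipeline variant
+// returned by BlockGemmMXNBSPipeline_Selector() for BlockGemmPipelineVersion::v1
+// with GUFusion disabled (see blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp
+// earlier in this diff).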
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 1 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v1 + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::APackedSize; + using Base::BMmaKStride; + using Base::BPackedSize; + using Base::KThreadChunk; + + using Base::KXdlPack; + using Base::MXdlPack; + using Base::NXdlPack; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_buf); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + // ------------------------------------------------------------------------------------------- + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + 
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference( + Number{})); + }); + }); + }); + }); + }); + }); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_buf); + + 
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + 
static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp new file mode 100755 index 000000000000..b48e464fee75 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_v3.hpp @@ -0,0 +1,1126 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
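+ // [Editorial sketch, not part of the original patch] Worked example of the
+ // scale bookkeeping set up by the v3 pipeline below, using purely
+ // hypothetical values (ScaleBlockSize = 32 is the usual MX block size; the
+ // other numbers are illustrative only):
+ //
+ //   KPerBlock = 256, ScaleBlockSize = 32
+ //     -> ScalesPerKBlockSize          = 256 / 32 = 8 mx-vectors per K block
+ //   APackedSize = 2, KPack = 16, xdlops_gemm.K0PerXdlops = 4
+ //     -> ScalesPerXdlopsRun           = (2 * 16 * 4) / 32 = 4
+ //   xdlops_gemm.mfma_instr.num_input_blks = 4
+ //     -> ScalesPerXdlopsRunPerThread  = 4 / 4 = 1   (static_assert: must be > 0)
+ //   sizeof(AScaleDataType) = 4, sizeof(e8m0_bexp_t) = 1, KXdlPack = MXdlPack = 2
+ //     -> scale_pack_size_a = 4, a_scale_thread_vec_size = (2 * 2) / 4 = 1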
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3 + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::APackedSize; + using Base::BMmaKStride; + using Base::BPackedSize; + using Base::KThreadChunk; + + using Base::KXdlPack; + using Base::MXdlPack; + using Base::NXdlPack; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack; + constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize; + + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale + num_buffer_load_b_scale; + + constexpr auto mfma_perstage_more = + math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total); + constexpr auto mfma_perstage_less = + math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total); + + constexpr auto mfma_stages_more = + num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total; + + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + if constexpr(i < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_a) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_a) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more) + { + 
static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_a) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma < num_dswrite_per_issue_b) + { + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& 
b_scale_grid_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I0)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk 
= + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + // Initialize C + c_thread_buf.Clear(); + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(scale_mem_buf)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(scale_mem_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + 
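+                                            // Assemble one KPack-wide A/B operand pair for this
+                                            // (imxdl, inxdl, kxdl) position from the per-thread
+                                            // buffers, reinterpret them as the MFMA input vector
+                                            // types, and issue the scaled MFMA: the packed e8m0
+                                            // block scales travel with the operands and the result
+                                            // accumulates into c_thread_buf at c_offset.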
vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()( + ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()( + ik) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference( + Number{})); + }); + }); + }); + }); + }); + }); + + // k indexes mapping to threads for 32x32x64: + // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc. + // t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc. + // k = 0 k = 1 + + // k indexes mapping to threads for 16x16x128: + // t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc. + // t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc. + // t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc. + // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc. + // k = 0 k = 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, + xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, + xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I1)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + 
b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + 
make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b 
= + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_selector.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_selector.hpp new file mode 100755 index 000000000000..f2a4eab39378 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_selector.hpp @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp" + +namespace ck { +template +constexpr auto BlockGemmMXPipeline_Selector() +{ + + // Hardware MX GEMM pipeline + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if constexpr(GUFusion) + { + return nullptr; + } + else + { + return nullptr; + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if constexpr(GUFusion) + { + return BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3< + BlkGemmPipeSche, + ThreadBlockSize, + ScaleBlockSize, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + ATileDesc, + BTileDesc, + AMmaTileDesc, + BMmaTileDesc, + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack>{}; + } + else + { + return BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3{}; + } + } + else + { + std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp new file mode 100755 index 000000000000..bb4286b3f50c --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp @@ -0,0 +1,1090 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
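+ // [Editorial sketch, not part of the original patch] How the host-side loop
+ // classification below is typically consumed, with assumed sizes; taking
+ // num_loop as the number of KPerBlock steps over K:
+ //
+ //   K = 4096, KPerBlock = 256  ->  num_loop = 16
+ //     BlockHasHotloop(16)  == true              // 16 > PrefetchStages (= 2)
+ //     BlockLoopTailNum(16) == TailNumber::Even
+ //
+ // Mirroring the _nbs_ v3 pipeline, the device-side Run() retires two K
+ // blocks per trip of its main do/while (ping-ponging the I0/I1 buffers),
+ // and the Even/Odd tail selects how many trailing blocks remain to be
+ // computed after the loop exits.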
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3 + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::APackedSize; + using Base::BMmaKStride; + using Base::BPackedSize; + using Base::KThreadChunk; + + using Base::KXdlPack; + using Base::MXdlPack; + using Base::NXdlPack; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack; + constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize; + + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 8 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 8 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale + num_buffer_load_b_scale; + + constexpr auto mfma_perstage_more = + math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total); + constexpr auto mfma_perstage_less = + math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total); + + constexpr auto mfma_stages_more = + num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + if constexpr(i < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < mfma_stages_more) + { + 
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_bufs, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_bufs, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + + // Global prefetch 1 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0)); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + 
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I0)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefetch 1, sync the async load + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I0), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + // Global prefetch 2 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1)); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I1)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) { + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf)); + b_blockwise_copy.Run( + 
b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(scale_comp_buf)); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(scale_mem_buf)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(scale_mem_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()( + ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()( + ik) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; + + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; + + constexpr index_t c_offset = + 
c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference( + Number{})); + }); + }); + }); + }); + }); + }); + + // k indexes mapping to threads for 32x32x64: + // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc. + // t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc. + // k = 0 k = 1 + + // k indexes mapping to threads for 16x16x128: + // t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc. + // t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc. + // t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc. + // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc. + // k = 0 k = 1 + // __builtin_amdgcn_s_waitcnt(3952); + // block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, + xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(scale_mem_buf), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, + xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(scale_mem_buf), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I1)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t 
a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; + + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I1), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I1), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + 
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; + + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; + + using mfma_scale_input_type_a = 
typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp new file mode 100755 index 000000000000..52ab86b6d410 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp" + +namespace ck { +template +constexpr auto BlockGemmMXPipeline_Selector() +{ + + // Hardware MX GEMM pipeline + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + return BlockwiseGemmXdlops_pipeline_v1_mx{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return BlockwiseGemmXdlops_pipeline_v3_mx{}; + } + else + { + std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp new file mode 100755 index 000000000000..fae00704ae6b --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
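// Both selector headers in this change (the MX variant above and the generic one
// below) use the same compile-time dispatch idiom: the pipeline version is a template
// parameter and `if constexpr` returns an instance of the matching pipeline struct, so
// the choice is resolved entirely at compile time. A minimal sketch of the idiom
// follows; PipelineVersionSketch / PipeV1Sketch / PipeV3Sketch are illustrative
// stand-ins, not types defined by this change.
enum class PipelineVersionSketch { v1, v3 };
struct PipeV1Sketch { static constexpr int prefetch_stages = 1; };
struct PipeV3Sketch { static constexpr int prefetch_stages = 2; };

template <PipelineVersionSketch Ver>
constexpr auto SelectPipelineSketch()
{
    if constexpr(Ver == PipelineVersionSketch::v1) { return PipeV1Sketch{}; }
    else { return PipeV3Sketch{}; }
}
// Usage: decltype(SelectPipelineSketch<PipelineVersionSketch::v3>()) is PipeV3Sketch,
// so its prefetch_stages (2) is available as a compile-time constant.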
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp" + +namespace ck { + +template +constexpr auto BlockGemmPipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + return BlockwiseGemmXdlops_pipeline_v1{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + return BlockwiseGemmXdlops_pipeline_v2{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return BlockwiseGemmXdlops_pipeline_v3{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return BlockwiseGemmXdlops_pipeline_v4{}; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v5) + { + return BlockwiseGemmXdlops_pipeline_v5{}; + } + else + { + std::cerr << "BlockGemmPipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp new file mode 100755 index 000000000000..f597573dc2a4 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -0,0 +1,731 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 1 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v1 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v1 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::KRepeat; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& 
a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + // ------------------------------------------------------------------------------------------- + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + 
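// Pack KPack contiguous elements into the A/B register vectors below so that each
// xdlops_gemm.Run call receives a complete MFMA operand in a single issue.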
vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +template +struct BlockwiseGemmXdlops_pipeline_v1 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KPerThread; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS; + static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack); + static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + // 
------------------------------------------------------------------------------------------- + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + __builtin_amdgcn_sched_barrier(0); + // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, + // but except the first, as we can shorten non-MAC cluster a bit and there's no + // observable negative impact. The desired effect is waves in a workgroup + // executing MAC in sync. This avoids some out-of-sync waves hijacking MAC + // resource from other workgroups and reducing the chance of latency hiding by + // waiting for the rest of the workgroup at the eventual sync point. + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from + // blockwise_gemm is moved here B) reduce VMEM FIFO congestion by + // applying small delays to different wavefronts It is performed + // near the end of MAC cluster to minimize lgkmcnt penalty + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + + // block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + 
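// Stage the matching B fragments for this k0 so the tail MFMA pass below sees the
// same operand layout as the hot loop.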
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + } + } + + protected: + // K->M loopover + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I1)); + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I1)); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp new file mode 100755 index 000000000000..ea4f5e4a286f --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp @@ -0,0 +1,868 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
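// The A/B-scale pipeline below keeps a per-thread product of the block scales
// (c_scale = a_scale * b_scale) and folds it into the accumulator after each group of
// MFMAs: C += (partial MFMA sum over one scale block) * c_scale. A scalar model of
// that accumulation is sketched here with illustrative names only (block_scaled_dot
// and its parameters are not part of the kernel's interface).
#include <cstddef>
#include <vector>

// One output element of a block-scaled GEMM: K is split into scale blocks of size
// block_k, and each block contributes partial * (a_scale[kb] * b_scale[kb]).
inline float block_scaled_dot(const std::vector<float>& a,
                              const std::vector<float>& b,
                              const std::vector<float>& a_scale,
                              const std::vector<float>& b_scale,
                              std::size_t block_k)
{
    float c = 0.0f;
    for(std::size_t kb = 0; kb < a_scale.size(); ++kb)
    {
        float partial = 0.0f; // plays the role of c_thread_buf_per_scale
        for(std::size_t k = 0; k < block_k; ++k)
            partial += a[kb * block_k + k] * b[kb * block_k + k];
        c += partial * (a_scale[kb] * b_scale[kb]); // c_scale = a_scale * b_scale
    }
    return c;
}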
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v1_ab_scale +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v1_ab_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + static constexpr index_t AMmaKStride = xdlops_gemm.K0PerXdlops * KPack; + static constexpr index_t BMmaKStride = xdlops_gemm.K0PerXdlops * KPack; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + // Force mfma not cross the scaleblock + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]); + } + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop == 1 ? TailNumber::Odd : TailNumber::Full; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? 
HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / sizeof(BDataType) + // ? sizeof(ComputeDataType) / sizeof(ADataType) + // : sizeof(ComputeDataType) / sizeof(BDataType); + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else 
+ { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + const CScaleThreadDesc& c_scale_thread_desc, + CThreadBuffer& c_thread_buf, + // AScaleThreadCopy + const AScaleGridDesc& a_scale_grid_desc, + const AScaleThreadDesc& a_scale_thread_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const AScaleThreadTransferStep& a_scale_thread_copy_step, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop) const + { + __builtin_amdgcn_sched_barrier(0); + // assume kperblock = scaleblockk + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{}); + constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{}); + constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); + + static_for<0, num_scale_m_block, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr 
index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + StaticBufferTupleOfVector + c_thread_buf_per_scale; + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t 
cscale_offset = + CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert( + c_scale_thread_buf[Number{}]); + }); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + i += 1; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, 
n0 * num_scale_n_block / NRepeat)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert( + c_scale_thread_buf[Number{}]); + }); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_buf(Number{}) = + a_scale_thread_buf[Number{}] * + b_scale_thread_buf[Number{}]; + }); + }); + }); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert( + c_scale_thread_buf[Number{}]); + }); + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto kscale0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), 
+ c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset( + make_tuple(kscale0, m0, n0 * num_scale_n_block / NRepeat)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert( + c_scale_thread_buf[Number{}]); + }); + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + } + } + + protected: + using Base::a_thread_desc_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp new file mode 100755 index 000000000000..4246f4a44e76 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp @@ -0,0 +1,403 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 1 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v1_b_scale +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v1_b_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::KRepeat; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const 
ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop, + index_t num_loop_per_scale) const + { + // assume kperblock = scaleblockk + ignore = num_loop_per_scale; + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + auto c_thread_buf_per_scale = remove_cvref_t(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + // ------------------------------------------------------------------------------------------- + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + 
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(b_scale_thread_buf[n0]); + }); + }); + }); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{})); + }); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(b_scale_thread_buf[n0]); + }); + }); + }); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp new file mode 100755 index 000000000000..f4337745bf2d --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp @@ -0,0 +1,672 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
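The scaled pipelines above all share one accumulation order: the MFMA partial sums for a given (m0, n0) tile are first collected in c_thread_buf_per_scale, and only the finished partial is multiplied by the dequantization scale and added into c_thread_buf. A minimal host-side model of that order (plain arrays and one hypothetical scale per output column, not the CK register buffers):

// Scalar model of the b-scale accumulation order used in the pipelines above.
// Illustrative only; the kernels operate on per-thread MFMA register tiles.
#include <cstddef>
#include <vector>

void gemm_b_scale_model(const std::vector<float>& A,       // M x K, row-major
                        const std::vector<float>& B,       // K x N, row-major
                        const std::vector<float>& b_scale, // one scale per output column (assumption)
                        std::vector<float>& C,             // M x N, accumulated in place
                        std::size_t M, std::size_t N, std::size_t K)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        for(std::size_t n = 0; n < N; ++n)
        {
            float partial = 0.f; // plays the role of c_thread_buf_per_scale
            for(std::size_t k = 0; k < K; ++k)
                partial += A[m * K + k] * B[k * N + n];
            // the scale is applied once per accumulated tile, not once per MAC
            C[m * N + n] += partial * b_scale[n];
        }
    }
}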
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 1 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v1_mx +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v1_mx + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::APackedSize; + using Base::BMmaKStride; + using Base::BPackedSize; + using Base::KThreadChunk; + + using Base::KXdlPack; + using Base::MXdlPack; + using Base::NXdlPack; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto AScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + static constexpr auto BScalesPerXdlopsRun = + (BPackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThreadA = + AScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + static constexpr auto ScalesPerXdlopsRunPerThreadB = + BScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& 
a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_buf); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefill 1 + __builtin_amdgcn_s_waitcnt(3952); // wait for EXP_CNT, LDS, GDS, Constant and Message + block_sync_lds(); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + // ------------------------------------------------------------------------------------------- + + // wait previous blockwise copy to finish + + // k indexes mapping to threads for 32x32x64: + // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc. + // t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc. 
+ // k = 0 k = 1 + + // k indexes mapping to threads for 16x16x128: + // t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc. + // t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc. + // t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc. + // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc. + // k = 0 k = 1 + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / BPackedSize / KThreadChunk, 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + // load for next k loop + block_sync_lds(); + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThreadA && + 0 < ScalesPerXdlopsRunPerThreadB, + "Must have at least one scale per Xdlops per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + }); + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; + + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; 
+ using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference( + Number{})); + }); + }); + }); + }); + }); + }); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_buf); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + __builtin_amdgcn_s_waitcnt(3952); // wait for EXP_CNT and LGKM_CNT + block_sync_lds(); + + i += 1; + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / BPackedSize / KThreadChunk, 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + 
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThreadA && + 0 < ScalesPerXdlopsRunPerThreadB, + "Must have at least one scale per Xdlops per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + }); + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; + + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp new file mode 100755 index 000000000000..a6b5e272ff27 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp @@ -0,0 +1,1154 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
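In the MX pipeline that ends above, the e8m0 block-scale exponents are carried in packed form: scale_pack_size_a/b is sizeof(AScaleDataType or BScaleDataType) divided by one e8m0 byte, and KXdlPack * MXdlPack (or NXdlPack) must be a multiple of that pack size. A standalone sketch of this bookkeeping, with assumed pack parameters and the usual OCP MX interpretation of e8m0 as a pure power-of-two scale 2^(e - 127):

// Sketch of the MX scale-packing arithmetic; the concrete types and pack factors
// below are assumptions, the kernel derives them from its template parameters.
#include <cmath>
#include <cstdint>
#include <cstdio>

using mx_scale_t = std::uint8_t;      // stands in for e8m0_bexp_t (one byte per scale)
using AScaleDataType = std::uint32_t; // assumed carrier type: four packed e8m0 scales

constexpr int KXdlPack = 2;
constexpr int MXdlPack = 2;

constexpr int scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); // = 4
static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
              "A scale pack data type too large!");
constexpr int a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a;

// e8m0 holds only a biased exponent; the represented scale is 2^(e - 127).
float decode_e8m0(std::uint8_t e) { return std::ldexp(1.0f, static_cast<int>(e) - 127); }

int main()
{
    std::printf("packed scale elements per xdlops run: %d\n", a_scale_thread_vec_size);
    std::printf("e8m0 0x7F -> %g, 0x81 -> %g\n", decode_e8m0(0x7F), decode_e8m0(0x81));
}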
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Maximum Global Memory throughput pipeline with >=32KB data in fly +// GlobalPrefetchStages: >=2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v2 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v2 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::KRepeat; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t WgpPerCU = + (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1; + static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( + 32768 / WgpPerCU, + (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + static constexpr index_t PrefetchStages = + FullMemBandPrefetchStages >= 2 + ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8 + : 2; + + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = PrefetchStages; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % PrefetchStages == 1) + { + return TailNumber::One; + } + else if(num_loop % PrefetchStages == 2) + { + return TailNumber::Two; + } + else if(num_loop % PrefetchStages == 3) + { + return TailNumber::Three; + } + else if(num_loop % PrefetchStages == 4) + { + return TailNumber::Four; + } + else if(num_loop % PrefetchStages == 5) + { + return TailNumber::Five; + } + else if(num_loop % PrefetchStages == 6) + { + return TailNumber::Six; + } + else if(num_loop % PrefetchStages == 7) + { + return TailNumber::Seven; + } + else + { + return TailNumber::Full; + } + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, 
a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Global prefetch [2, PrefetchStages] + static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) { + // ------------------------------------------------------------------------------------------- + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite( + a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + b_blockwise_copy.RunWrite( + b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + i += PrefetchStages; + } while(i < (num_loop - PrefetchStages)); + } + + // tail + + auto LoopTailFunc = [&](auto tail_num) { + static_for<1, tail_num, 1>{}([&](auto iprefetch) { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + 
b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch); + }); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }; + + if constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Two) + { + LoopTailFunc(Number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + LoopTailFunc(Number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + LoopTailFunc(Number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + LoopTailFunc(Number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + LoopTailFunc(Number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + LoopTailFunc(Number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + LoopTailFunc(Number{}); + } + } + + protected: 
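// The per-thread copy objects and the thread/block tile descriptors of this
// specialization are reused unchanged from the shared xdlops pipeline base.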
+ using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +template +struct BlockwiseGemmXdlops_pipeline_v2 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KPerThread; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS; + static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack); + static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; + + static constexpr index_t WgpPerCU = + (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1; + static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( + 32768 / WgpPerCU, + (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + static constexpr index_t PrefetchStages = + FullMemBandPrefetchStages >= 2 + ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8 + : 2; + + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = PrefetchStages; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % PrefetchStages == 1) + { + return TailNumber::One; + } + else if(num_loop % PrefetchStages == 2) + { + return TailNumber::Two; + } + else if(num_loop % PrefetchStages == 3) + { + return TailNumber::Three; + } + else if(num_loop % PrefetchStages == 4) + { + return TailNumber::Four; + } + else if(num_loop % PrefetchStages == 5) + { + return TailNumber::Five; + } + else if(num_loop % PrefetchStages == 6) + { + return TailNumber::Six; + } + else if(num_loop % PrefetchStages == 7) + { + return TailNumber::Seven; + } + else + { + return TailNumber::Full; + } + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + 
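// Note: GlobalBufferNum equals PrefetchStages here, so each RunRead(..., iprefetch)
// stages a K-block in its own global-prefetch slot; in the main loop the matching
// RunWrite uses Number<(iprefetch + 1) % PrefetchStages>{}, keeping later global
// loads in flight while compute drains the tile currently resident in LDS.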
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Global prefetch [2, PrefetchStages] + static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) { + // ------------------------------------------------------------------------------------------- + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + __builtin_amdgcn_sched_barrier(0); + // NOTE: Synchronize threads in a workgroup at the start of each MAC + // cluster, but except the first, as we can shorten non-MAC cluster a bit + // and there's no observable negative impact. The desired effect is waves in + // a workgroup executing MAC in sync. This avoids some out-of-sync waves + // hijacking MAC resource from other workgroups and reducing the chance of + // latency hiding by waiting for the rest of the workgroup at the eventual + // sync point. 
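// In the MAC cluster below, __builtin_amdgcn_s_setprio(1) is issued right after the
// first xdlops_gemm.Run of the cluster and __builtin_amdgcn_s_setprio(0) at the
// cluster end, so the wave runs at raised priority only while it is doing MFMA work;
// the surrounding __builtin_amdgcn_sched_barrier(0) calls stop the compiler from
// rescheduling instructions across those priority changes.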
+ if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from + // blockwise_gemm is moved here B) reduce VMEM FIFO congestion + // by applying small delays to different wavefronts It is + // performed near the end of MAC cluster to minimize lgkmcnt + // penalty + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + + // block_sync_lds(); + a_blockwise_copy.RunWrite( + a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + b_blockwise_copy.RunWrite( + b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + i += PrefetchStages; + } while(i < (num_loop - PrefetchStages)); + } + + // tail + + auto LoopTailFunc = [&](auto tail_num) { + static_for<1, tail_num, 1>{}([&](auto iprefetch) { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value 
== KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch); + }); + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + }; + + if constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + 
vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + } + else if constexpr(TailNum == TailNumber::Two) + { + LoopTailFunc(Number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + LoopTailFunc(Number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + LoopTailFunc(Number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + LoopTailFunc(Number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + LoopTailFunc(Number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + LoopTailFunc(Number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + LoopTailFunc(Number{}); + } + } + + protected: + // K->M loopover + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I1)); + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I1)); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp new file mode 100755 index 000000000000..0c030030fe28 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp @@ -0,0 +1,676 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
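Both v2 specializations (and the a/b-scale variant that follows) size their global-prefetch depth from a 32 KiB-per-CU in-flight budget and then pick the epilogue from num_loop % PrefetchStages. A small host-side model of that sizing, with assumed block-tile and element sizes:

// Host-side model of the v2 prefetch-depth and tail selection.
// Tile sizes and element widths are assumptions for illustration only.
#include <cstdio>

constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

constexpr int BlockSize = 256, WarpSize = 64;
constexpr int MPerBlock = 128, NPerBlock = 128, KPerBlock = 64;
constexpr int ABytes = 2, BBytes = 2; // e.g. bf16 A and B tiles

constexpr int WgpPerCU = (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
constexpr int FullMemBandPrefetchStages =
    integer_divide_ceil(32768 / WgpPerCU,
                        (MPerBlock * ABytes + NPerBlock * BBytes) * KPerBlock);
constexpr int PrefetchStages =
    FullMemBandPrefetchStages >= 2
        ? (FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8)
        : 2; // clamped to the range [2, 8]

// Remainders 1..7 map to TailNumber::One..Seven; everything else runs the Full tail.
const char* tail_for(int num_loop)
{
    static const char* names[] = {"Full", "One", "Two", "Three", "Four", "Five", "Six", "Seven"};
    const int r = num_loop % PrefetchStages;
    return (r >= 1 && r <= 7) ? names[r] : "Full";
}

int main()
{
    std::printf("PrefetchStages = %d, tail for 11 K-loops = %s\n", PrefetchStages, tail_for(11));
}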
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Maximum Global Memory throughput pipeline with >=32KB data in fly +// GlobalPrefetchStages: >=2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v2_ab_scale +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v2_ab_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::KRepeat; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t WgpPerCU = + (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1; + static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( + 32768 / WgpPerCU, + (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + static constexpr index_t PrefetchStages = + FullMemBandPrefetchStages >= 2 + ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8 + : 2; + + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = PrefetchStages; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % PrefetchStages == 1) + { + return TailNumber::One; + } + else if(num_loop % PrefetchStages == 2) + { + return TailNumber::Two; + } + else if(num_loop % PrefetchStages == 3) + { + return TailNumber::Three; + } + else if(num_loop % PrefetchStages == 4) + { + return TailNumber::Four; + } + else if(num_loop % PrefetchStages == 5) + { + return TailNumber::Five; + } + else if(num_loop % PrefetchStages == 6) + { + return TailNumber::Six; + } + else if(num_loop % PrefetchStages == 7) + { + return TailNumber::Seven; + } + else + { + return TailNumber::Full; + } + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // AScaleThreadCopy + const AScaleGridDesc& a_scale_grid_desc, + const AScaleThreadDesc& a_scale_thread_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const AScaleThreadTransferStep& a_scale_thread_copy_step, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& 
b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop, + index_t num_loop_per_scale) const + { + // assume kperblock = scaleblockk + ignore = num_loop_per_scale; + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if(num_loop_per_scale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Initialize C + c_thread_buf.Clear(); + + // Global prefetch [2, PrefetchStages] + static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + auto c_thread_buf_per_scale = remove_cvref_t(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + 
c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[m0]) * + type_convert(b_scale_thread_buf[I0]); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + }); + + if(num_loop_per_scale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step); + + block_sync_lds(); + a_blockwise_copy.RunWrite( + a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + b_blockwise_copy.RunWrite( + b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + i += PrefetchStages; + } while(i < (num_loop - PrefetchStages)); + } + + // tail + auto LoopTailFunc = [&](auto tail_num) { + static_for<1, tail_num, 1>{}([&](auto iprefetch) { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[m0]) * + type_convert(b_scale_thread_buf[I0]); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, 
a_scale_thread_copy_step.At(Number<0>{})); + }); + + if(num_loop_per_scale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch); + }); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert(a_scale_thread_buf[m0]) * + type_convert(b_scale_thread_buf[I0]); + }); + }); + }); + }; + + if constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + c_thread_buf_per_scale.Clear(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + 
type_convert(a_scale_thread_buf[m0]) * + type_convert(b_scale_thread_buf[I0]); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Two) + { + LoopTailFunc(Number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + LoopTailFunc(Number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + LoopTailFunc(Number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + LoopTailFunc(Number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + LoopTailFunc(Number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + LoopTailFunc(Number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + LoopTailFunc(Number{}); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp new file mode 100755 index 000000000000..69002d79627b --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_b_scale.hpp @@ -0,0 +1,1248 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Maximum Global Memory throughput pipeline with >=32KB data in fly +// GlobalPrefetchStages: >=2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v2_b_scale +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v2_b_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::KRepeat; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t WgpPerCU = + (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1; + static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( + 32768 / WgpPerCU, + (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + static constexpr index_t PrefetchStages = + FullMemBandPrefetchStages >= 2 + ? FullMemBandPrefetchStages <= 8 ? 
FullMemBandPrefetchStages : 8 + : 2; + + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = PrefetchStages; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % PrefetchStages == 1) + { + return TailNumber::One; + } + else if(num_loop % PrefetchStages == 2) + { + return TailNumber::Two; + } + else if(num_loop % PrefetchStages == 3) + { + return TailNumber::Three; + } + else if(num_loop % PrefetchStages == 4) + { + return TailNumber::Four; + } + else if(num_loop % PrefetchStages == 5) + { + return TailNumber::Five; + } + else if(num_loop % PrefetchStages == 6) + { + return TailNumber::Six; + } + else if(num_loop % PrefetchStages == 7) + { + return TailNumber::Seven; + } + else + { + return TailNumber::Full; + } + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Global prefetch [2, PrefetchStages] + static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) { + // ------------------------------------------------------------------------------------------- + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + 
b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite( + a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + b_blockwise_copy.RunWrite( + b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + i += PrefetchStages; + } while(i < (num_loop - PrefetchStages)); + } + + // tail + + auto LoopTailFunc = [&](auto tail_num) { + static_for<1, tail_num, 1>{}([&](auto iprefetch) { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch); + }); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }; + + if 
constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Two) + { + LoopTailFunc(Number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + LoopTailFunc(Number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + LoopTailFunc(Number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + LoopTailFunc(Number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + LoopTailFunc(Number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + LoopTailFunc(Number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + LoopTailFunc(Number{}); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +template +struct BlockwiseGemmXdlops_pipeline_v2_b_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KPerThread; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS; + static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack); + static constexpr index_t KRepeat = KPerThread / KPerInnerLoop; + + static constexpr index_t WgpPerCU = + (4 * WarpSize / BlockSize) >= 1 ? 
4 * WarpSize / BlockSize : 1; + static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil( + 32768 / WgpPerCU, + (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); + static constexpr index_t PrefetchStages = + FullMemBandPrefetchStages >= 2 + ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8 + : 2; + + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = PrefetchStages; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % PrefetchStages == 1) + { + return TailNumber::One; + } + else if(num_loop % PrefetchStages == 2) + { + return TailNumber::Two; + } + else if(num_loop % PrefetchStages == 3) + { + return TailNumber::Three; + } + else if(num_loop % PrefetchStages == 4) + { + return TailNumber::Four; + } + else if(num_loop % PrefetchStages == 5) + { + return TailNumber::Five; + } + else if(num_loop % PrefetchStages == 6) + { + return TailNumber::Six; + } + else if(num_loop % PrefetchStages == 7) + { + return TailNumber::Seven; + } + else + { + return TailNumber::Full; + } + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + const BScaleGridDesc& b_scale_grid_desc, + // BScaleThreadCopy + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num loop + index_t num_loop, + index_t num_loop_per_scale) const + { + ignore = num_loop_per_scale; + + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Global prefetch [2, PrefetchStages] + static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + 
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + + auto c_thread_buf_per_scale = remove_cvref_t(); // need? + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) { + // ------------------------------------------------------------------------------------------- + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + __builtin_amdgcn_sched_barrier(0); + // NOTE: Synchronize threads in a workgroup at the start of each MAC + // cluster, but except the first, as we can shorten non-MAC cluster a bit + // and there's no observable negative impact. The desired effect is waves in + // a workgroup executing MAC in sync. This avoids some out-of-sync waves + // hijacking MAC resource from other workgroups and reducing the chance of + // latency hiding by waiting for the rest of the workgroup at the eventual + // sync point. + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from + // blockwise_gemm is moved here B) reduce VMEM FIFO congestion + // by applying small delays to different wavefronts It is + // performed near the end of MAC cluster to minimize lgkmcnt + // penalty + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + + // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) + // { + // constexpr index_t c_offset = + // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + // c_thread_buf(Number{}) += + // c_thread_buf_per_scale[Number{}] * + // type_convert(b_scale_thread_buf[n0]); + // }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + + // static_for<0, NRepeat, 1>{}([&](auto n0) { + // b_scale_thread_copy.Run(b_scale_grid_desc, + // b_scale_grid_buf, + // b_scale_thread_desc, + // make_tuple(n0, I0), + // b_scale_thread_buf); + + // 
b_scale_thread_copy.MoveSrcSliceWindow( + // b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{})); + // }); + // b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + // b_scale_thread_copy_step.At(Number<1>{})); + + // block_sync_lds(); + a_blockwise_copy.RunWrite( + a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + b_blockwise_copy.RunWrite( + b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{}); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + }); + i += PrefetchStages; + } while(i < (num_loop - PrefetchStages)); + } + + // tail + + auto LoopTailFunc = [&](auto tail_num) { + static_for<1, tail_num, 1>{}([&](auto iprefetch) { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + + // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + // constexpr index_t c_offset = + // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + // c_thread_buf(Number{}) += + // c_thread_buf_per_scale[Number{}] * + // type_convert(b_scale_thread_buf[n0]); + // }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + + // static_for<0, NRepeat, 1>{}([&](auto n0) { + // b_scale_thread_copy.Run(b_scale_grid_desc, + // b_scale_grid_buf, + // b_scale_thread_desc, + // make_tuple(n0, I0), + // b_scale_thread_buf); + + // b_scale_thread_copy.MoveSrcSliceWindow( + // b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{})); + // }); + // b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + // b_scale_thread_copy_step.At(Number<1>{})); + + 
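// The s_setprio/sched_barrier bracketing above mirrors the hot loop: wave
// priority is raised to 1 at the first MFMA of each MAC cluster and dropped
// back to 0 once the cluster finishes, with __builtin_amdgcn_sched_barrier(0)
// fences so the compiler cannot reorder instructions across those points. The
// block_sync_lds() issued near the end of the last cluster both guards the
// LDS writes that follow here and staggers per-wavefront VMEM traffic, as the
// NOTE in the main body explains.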
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch); + }); + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + + // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + // constexpr index_t c_offset = + // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + // c_thread_buf(Number{}) += + // c_thread_buf_per_scale[Number{}] * + // type_convert(b_scale_thread_buf[n0]); + // }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + }; + + if constexpr(TailNum == TailNumber::One) + { + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + if constexpr(k0.value != 0 || KRepeat == 1) + { + __builtin_amdgcn_s_barrier(); + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + if 
constexpr(k0.value == KRepeat - 1 && + k_.value == KPerInnerLoop - KPack && + m0.value == MRepeat - 1 && n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + + // static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + // constexpr index_t c_offset = + // c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + // c_thread_buf(Number{}) += + // c_thread_buf_per_scale[Number{}] * + // type_convert(b_scale_thread_buf[n0]); + // }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + } + else if constexpr(TailNum == TailNumber::Two) + { + LoopTailFunc(Number<2>{}); + } + else if constexpr(TailNum == TailNumber::Three) + { + LoopTailFunc(Number<3>{}); + } + else if constexpr(TailNum == TailNumber::Four) + { + LoopTailFunc(Number<4>{}); + } + else if constexpr(TailNum == TailNumber::Five) + { + LoopTailFunc(Number<5>{}); + } + else if constexpr(TailNum == TailNumber::Six) + { + LoopTailFunc(Number<6>{}); + } + else if constexpr(TailNum == TailNumber::Seven) + { + LoopTailFunc(Number<7>{}); + } + else if constexpr(TailNum == TailNumber::Full) + { + LoopTailFunc(Number{}); + } + } + + protected: + // K->M loopover + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I1)); + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + I1)); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp new file mode 100755 index 000000000000..b5d6180ab3c2 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp @@ -0,0 +1,463 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
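// The v2 pipelines above pick their epilogue with BlockLoopTailNum(num_loop):
// a remainder of 1..7 modulo PrefetchStages selects TailNumber::One..Seven
// (One is handled inline, Two..Seven through LoopTailFunc), and a remainder of
// 0 selects TailNumber::Full, which drains a complete prefetch window. The v3
// pipeline defined in this header always reports TailNumber::Full instead.
// A small standalone sketch of that classification, using a reduced
// illustrative enum rather than the full ck::TailNumber:
enum class TailKind { One = 1, Two, Three, Four, Five, Six, Seven, Full };

constexpr TailKind block_loop_tail_num(int num_loop, int prefetch_stages)
{
    const int r = num_loop % prefetch_stages;
    // Remainders 1..7 get a dedicated tail; a remainder of 0 means the last
    // iterations still fill a whole prefetch window.
    return (r >= 1 && r <= 7) ? static_cast<TailKind>(r) : TailKind::Full;
}

static_assert(block_loop_tail_num(17, 8) == TailKind::One, "tail sketch");
static_assert(block_loop_tail_num(16, 8) == TailKind::Full, "tail sketch");
static_assert(block_loop_tail_num(11, 4) == TailKind::Three, "tail sketch");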
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v3 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v3 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? 
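// Worked example of the interleave rates computed above, assuming a 32-cycle
// MFMA and a 16-byte ds_read (8-cycle issue): ds_read_a_mfma_rate
// = (32 - 4 + 2*8 - 1) / (2*8) = 43 / 16 = 2, i.e. up to two LDS reads are
// grouped with each MFMA in stage 2, and num_dsread_a_mfma then becomes
// ceil(num_ds_read_inst_a / 2).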
+ // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / sizeof(BDataType) + // ? sizeof(ComputeDataType) / sizeof(ADataType) + // : sizeof(ComputeDataType) / sizeof(BDataType); + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + __builtin_amdgcn_sched_barrier(0); + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 1 + 
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 1)); + } + // tail + if constexpr(TailNum == TailNumber::Full) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle 
+ // latency + // __builtin_amdgcn_sched_barrier(0); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp new file mode 100755 index 000000000000..a7d22066acd6 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp @@ -0,0 +1,623 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v3_ab_scale +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v3_ab_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? 
HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? + // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / sizeof(BDataType) + // ? sizeof(ComputeDataType) / sizeof(ADataType) + // : sizeof(ComputeDataType) / sizeof(BDataType); + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else 
+ { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + const CScaleThreadDesc& c_scale_thread_desc, + CThreadBuffer& c_thread_buf, + // AScaleThreadCopy + const AScaleGridDesc& a_scale_grid_desc, + const AScaleThreadDesc& a_scale_thread_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const AScaleThreadTransferStep& a_scale_thread_copy_step, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num_loop + index_t num_loop) const + { + __builtin_amdgcn_sched_barrier(0); + static_assert(CScaleThreadDesc{}.GetLength(Number<0>{}) == 1, + "Pipeline v3 only support scaleblocksliceK=1"); + static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1, + "Pipeline v3 only support scaleblocksliceN=1"); + // assume kperblock = scaleblockk + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + auto c_scale_thread_buf = make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0]; + }); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + 
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + StaticBufferTupleOfVector + c_thread_buf_per_scale; + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert(c_scale_thread_buf[m0]); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0]; + }); + + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + 
a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, I0), + a_scale_thread_buf); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{})); + }); + + if constexpr(NumKBlockPerScale == 1) + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{})); + } + else + { + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{})); + } + + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(I0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step); + + i += 1; + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + static_for<0, KRepeat, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); + }); + static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert(c_scale_thread_buf[m0]); + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp new file mode 100755 index 000000000000..3179a90b7fd5 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_b_scale.hpp @@ -0,0 +1,530 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
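+//
+// Note: this header appears to extend the stock xdlops v3 compute pipeline with a
+// per-block B-scale path. Run() takes BScaleGridDesc/BScaleThreadCopy arguments,
+// streams one scale per NRepeat tile, and applies it inside b_thread_copy_.Run()
+// while copying B fragments from LDS into registers, so the MFMAs consume
+// pre-scaled B data; num_loop_per_scale controls how often the scale read window
+// advances along K.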
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v3_b_scale +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v3_b_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + // Separate this part? 
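+        // Note on the two-stage schedule below (sched_group_barrier masks as used
+        // throughout this file: 0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read,
+        // 0x200 = DS write): stage 1 pairs every buffer_load (VMEM read) with its
+        // share of DS writes and MFMAs; stage 2 interleaves the remaining DS reads
+        // with MFMAs at ds_read_{a,b}_mfma_rate reads per MFMA slot. Illustration
+        // only: with a hypothetical mfma_cycle of 32 and a 16-byte ds_read
+        // (issue cycle 8), the rate is ceil((32 - 4) / 16) = 2 DS reads per MFMA.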
+ // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) > + // sizeof(ComputeDataType) / sizeof(BDataType) + // ? sizeof(ComputeDataType) / sizeof(ADataType) + // : sizeof(ComputeDataType) / sizeof(BDataType); + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_mfma_per_issue = + num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num loop + index_t num_loop, + index_t num_loop_per_scale) const + { + __builtin_amdgcn_sched_barrier(0); + + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // B scale buffer + auto 
b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + + if(num_loop_per_scale == 1) + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + } + + constexpr auto num_scale_k_block = BScaleThreadDesc{}.GetLength(I1); + constexpr auto num_scale_krepeat = KRepeat / num_scale_k_block; + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_scale_thread_buf[Number{}], + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_buf); + + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{})); + }); + + if((i + 2) % num_loop_per_scale == 0) + { + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{})); + } + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename 
vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_scale_thread_buf[Number{}], + b_thread_desc_, + make_tuple(n0, I0, k0, I0), + b_thread_buf); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 1)); + } + // tail + if constexpr(TailNum == TailNumber::Full) + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp new file mode 100755 index 000000000000..fe7d84eda445 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp @@ -0,0 +1,1090 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
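+//
+// Note: this header appears to be the MX (microscaled) variant of the v3 pipeline.
+// Both A and B carry e8m0 block-exponent scales (mx_scale_t = e8m0_bexp_t, one
+// scale per ScaleBlockSize elements along K); the scales are prefetched per
+// MXdlPack/NXdlPack x KXdlPack group and passed to the scaled xdlops_gemm.Run()
+// overload together with the A/B fragments, so scaling happens inside the MFMA
+// accumulation rather than in the LDS-to-register copy.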
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v3_mx +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v3_mx + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::APackedSize; + using Base::BMmaKStride; + using Base::BPackedSize; + using Base::KThreadChunk; + + using Base::KXdlPack; + using Base::MXdlPack; + using Base::NXdlPack; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? 
TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack; + constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize; + + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4; + + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_a_mfma = + (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_b_mfma = + (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma); + constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale + num_buffer_load_b_scale; + + constexpr auto mfma_perstage_more = + math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total); + constexpr auto mfma_perstage_less = + math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total); + + constexpr auto mfma_stages_more = + num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + if constexpr(i < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < mfma_stages_more) + { + 
static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * + ds_read_a_mfma_rate, + 0); // DS read + } + }); + + static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier(0x100, + num_ds_read_inst_b - (num_dsread_b_mfma - 1) * + ds_read_b_mfma_rate, + 0); // DS read + } + }); + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_bufs, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_bufs, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + + // Global prefetch 1 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0)); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + 
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I0)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefetch 1, sync the async load + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I0), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + // Global prefetch 2 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1)); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I1)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) { + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(scale_comp_buf)); + b_blockwise_copy.Run( + 
b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(scale_comp_buf)); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(scale_mem_buf)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(scale_mem_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()( + ik) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()( + ik) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; + + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; + + constexpr index_t c_offset = + 
c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference( + Number{})); + }); + }); + }); + }); + }); + }); + + // k indexes mapping to threads for 32x32x64: + // t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc. + // t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc. + // k = 0 k = 1 + + // k indexes mapping to threads for 16x16x128: + // t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc. + // t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc. + // t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc. + // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc. + // k = 0 k = 1 + // __builtin_amdgcn_s_waitcnt(3952); + // block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, + xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(scale_mem_buf), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, + xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(scale_mem_buf), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I1)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t 
a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; + + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + + __builtin_amdgcn_s_waitcnt(3952); + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I1), + a_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k, + make_tuple(Number{}, + I0, + Number{}, + I0, + Number{}), + b_block_bufs(I1), + b_thread_desc_, + make_tuple(Number{}, + I0, + Number{}, + k, + Number{}), + b_thread_buf); + }); + }); + }); + + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + 
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; + + using mfma_scale_input_type_a = typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto kxdl = ikxdl + k0 * KXdlPack; + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = typename vector_type< // + ComputeTypeA, + xdlops_gemm.K1PerXdlops / APackedSize>::type; + + using mfma_input_type_b = typename vector_type< // + ComputeTypeB, + xdlops_gemm.K1PerXdlops / BPackedSize>::type; + + using mfma_scale_input_type_a = 
typename vector_type< // + AScaleDataType, + a_scale_thread_vec_size>::type; + using mfma_scale_input_type_b = typename vector_type< // + BScaleDataType, + b_scale_thread_vec_size>::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec + .template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec + .template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp new file mode 100755 index 000000000000..629bbb316f93 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp @@ -0,0 +1,1154 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 2 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle + : BlockwiseGemmXdlops_mx_pipeline_base + +{ + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; + using Base::A_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_m3_k; + using Base::b_block_desc_n0_n1_n2_n3_k; + + using Base::AMmaKStride; + using Base::APackedSize; + using Base::BMmaKStride; + using Base::BPackedSize; + using Base::KThreadChunk; + + using Base::KXdlPack; + using Base::MXdlPack; + using Base::NXdlPack; + + using AccType = typename Base::AccType; + using Tuple5 = typename Base::Tuple5; + using ComputeTypeA = typename Base::ComputeTypeA; + using 
ComputeTypeB = typename Base::ComputeTypeB; + + static constexpr index_t PrefetchStages = 2; + static constexpr index_t LocalPrefetchStages = 2; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1; + + static constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack; + static constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack; + static constexpr auto async_vmcnt = + num_buffer_load_a_scale + num_buffer_load_b_scale + HotLoopInstList::B_Buffer_Load_Inst_Num; + static constexpr auto async_vmcnt_encoding = 3952 + async_vmcnt % 16 + async_vmcnt / 16 * 16384; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block + + //> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = + (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; + + using mx_scale_t = e8m0_bexp_t; + static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a; + static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd; + } + + __device__ static constexpr auto HotLoopScheduler() + { + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_stage1 = + num_buffer_load_inst_b + num_buffer_load_a_scale + num_buffer_load_b_scale; + + constexpr auto num_buffer_load_stage2 = num_buffer_load_inst_a; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize; + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 
8 : 4; + constexpr auto ds_read_a_mfma_rate = + math::integer_divide_ceil(mfma_cycle - 8, 2 * ds_read_a_issue_cycle); + + // constexpr auto num_dsread_a_mfma = + // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + + constexpr auto num_total_stages = std::max(2, MRepeat); + + if constexpr(num_total_stages > 2) + { + // Group num_mfma_perstage num_ds_read_a_perstage + // since we want to reuse a local register buffer + constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages; + constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages; + + constexpr auto num_ds_read_a_mfma_perstage = + math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate); + + constexpr auto num_ds_read_a_prefetch_stages = 2; + + constexpr auto buffer_load_perstage_more = + math::integer_divide_ceil((num_buffer_load_stage1), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_less = + math::integer_divide_floor((num_buffer_load_stage1), (num_total_stages - 2)); + constexpr auto buffer_load_perstage_stage2 = + math::integer_divide_floor((num_buffer_load_stage2), 2); + + constexpr auto buffer_load_stages_more = + num_buffer_load_stage1 - + math::integer_divide_floor(num_buffer_load_stage1, (num_total_stages - 2)) * + ((num_total_stages - 2)); + + constexpr auto buffer_load_issue_point_interval_more = + num_mfma_perstage / buffer_load_perstage_more; + constexpr auto buffer_load_issue_point_interval_less = + num_mfma_perstage / buffer_load_perstage_less; + constexpr auto buffer_load_issue_point_interval_stage2 = + num_mfma_perstage / buffer_load_perstage_stage2; + + // Stage 1 + // global read more + static_for<0, buffer_load_stages_more, 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + if constexpr(imfma % buffer_load_issue_point_interval_more == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + 0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // global read less + static_for<0, (num_total_stages - 2 - buffer_load_stages_more), 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_less == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + 0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + + // Stage 2, Sync + // lds synchronization, prefetch next loop local A + static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto /*i*/) { + static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(imfma % buffer_load_issue_point_interval_stage2 == 0) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage)) + { + __builtin_amdgcn_sched_group_barrier( + 0x100, ds_read_a_mfma_rate, 0); // DS read + } + }); + }); + } + else + { + constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale + + num_buffer_load_b_scale; + constexpr auto num_dsread_a_mfma = math::integer_divide_ceil( + 
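+            // Fallback path for MRepeat <= 2: instead of staging the global loads across
+            // MRepeat MFMA groups, each buffer load (A, B, and both scale streams) is given
+            // its own short run of MFMAs (mfma_perstage_more / mfma_perstage_less), and the
+            // remaining A LDS reads are drained against the last num_dsread_a_mfma MFMAs.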
num_ds_read_inst_a, ds_read_a_mfma_rate); // how many mfma per dsread_a + + // stage 1 + constexpr auto num_mfma_stage1 = num_mfma_inst - num_dsread_a_mfma; + + constexpr auto mfma_perstage_more = + math::integer_divide_ceil(num_mfma_stage1, num_buffer_load_total); + constexpr auto mfma_perstage_less = + math::integer_divide_floor(num_mfma_stage1, num_buffer_load_total); + + constexpr auto mfma_stages_more = + num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total; + + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + if constexpr(i < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_a_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < + mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + static_for<0, num_buffer_load_b_scale, 1>{}([&](auto i) { + if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b + + num_buffer_load_a_scale) < mfma_stages_more) + { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + else + { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + }); + + // stage 2 + static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) { + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + }); + } + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_bufs, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& 
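+        // In this B-preshuffle pipeline the blockwise B copy streams the preshuffled B tile
+        // from global memory directly into per-thread register buffers (b_thread_bufs); the
+        // b_block_bufs argument is not used here (see "ignore = b_block_bufs" below).
+        // Only the A tile and the scales take the LDS / per-thread prefetch paths.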
b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_bufs, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + index_t num_loop) const + { + ignore = b_block_bufs; + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_scale_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + + // Global prefetch 1 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0)); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I0)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // Local prefetch 1, sync the async load + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + static_for<0, LocalPrefetchStages, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple( + I0, I0, Number{}, I0, Number{}), + a_block_bufs(I0), + a_thread_desc_, 
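+                        // Local prefetch of A: the first LocalPrefetchStages m0 rows are
+                        // staged from LDS into registers before the MFMA loop starts,
+                        // KThreadChunk elements at a time; a_k_step_chunk advances through
+                        // this thread's K elements for one xdlops_gemm.Run() call.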
+ make_tuple( + I0, I0, Number{}, k, Number{}), + a_thread_buf); + }); + }); + }); + + // Global prefetch 2 + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I1)); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + __builtin_amdgcn_sched_barrier(0); + constexpr index_t SwitchM = MRepeat - LocalPrefetchStages; + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc, + b_block_origin_idx, + b_thread_bufs(scale_mem_buf)); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(scale_mem_buf)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0)); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(scale_mem_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0)); + + // a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset( + make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset( + make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs( + scale_comp_buf)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type 
b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = b_thread_bufs + [scale_comp_buf][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + if constexpr(m0.value == SwitchM) + { + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + a_blockwise_copy.Run(a_grid_desc, + a_grid_buf, + a_block_desc, + a_block_bufs(scale_comp_buf)); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + } + + constexpr auto lds_buf = + m0.value >= SwitchM ? scale_mem_buf : scale_comp_buf; + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, + xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), + 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % + (MRepeat / MXdlPack)>{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(Number{}), + a_thread_desc_, + make_tuple(I0, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + }); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_thread_bufs(I1)); + + // Prefetch a_scales + static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc, + make_tuple(m0, k0, I0), + a_scale_thread_bufs(I1)); + + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0)); + }); + + // Prefetch b_scales + static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { + static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, k0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(0, I1, 0)); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0)); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto 
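+                        // m0/n0/k0 are split into a "major" index (which packed-xdlops group,
+                        // i.e. m0 / MXdlPack etc.) and a "minor" index (position inside the
+                        // MXdlPack/NXdlPack/KXdlPack group). The scale offsets use only the
+                        // major indices, while the A/B register reads and the accumulator
+                        // offset use both.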
in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + if constexpr(m0.value == SwitchM) + { + __builtin_amdgcn_s_waitcnt(async_vmcnt_encoding); + block_sync_lds(); + } + + constexpr auto lds_buf = m0.value >= SwitchM ? 
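+                    // Once m0 passes SwitchM (= MRepeat - LocalPrefetchStages) the A tile for
+                    // the following K block has landed in the other LDS buffer; the wait/sync
+                    // above retires that asynchronous load, so the remaining local prefetches
+                    // read from I1 while earlier m0 iterations still read I0.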
I1 : I0; + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % + (MRepeat / MXdlPack)>{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(Number{}), + a_thread_desc_, + make_tuple( + I0, I0, Number{}, k, Number{}), + a_thread_buf); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I1)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I1)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I1][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) + { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % + (MRepeat / MXdlPack)>{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I1), + a_thread_desc_, + make_tuple(I0, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + } + }); + } + else if constexpr(TailNum == TailNumber::Odd) + { + static_for<0, MRepeat, 
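+            // Odd tail: exactly one K block is left and everything it needs (A in LDS buffer
+            // I0, B and both scale streams in the I0 register buffers) has already been
+            // fetched, so no further global or scale loads are issued here; only the rolling
+            // local A prefetch for the remaining m0 iterations is still performed.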
1>{}([&](auto m0) { + constexpr auto im_major = m0 / MXdlPack; + constexpr auto im_minor = m0 % MXdlPack; + static_for<0, KRepeat, 1>{}([&](auto k0) { + constexpr auto ik_major = k0 / KXdlPack; + constexpr auto ik_minor = k0 % KXdlPack; + static_for<0, NRepeat, 1>{}([&](auto n0) { + constexpr auto in_major = n0 / NXdlPack; + constexpr auto in_minor = n0 % NXdlPack; + + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(im_major, ik_major, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(in_major, ik_major, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops " + "per Thread."); + + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_bufs(I0)[Number{}]; + }); + + static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) { + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_bufs(I0)[Number{}]; + }); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[I0][Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + + using mfma_input_type_b = + typename vector_type::type; + + using mfma_scale_input_type_a = + typename vector_type::type; + using mfma_scale_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = c_thread_desc_.CalculateOffset( + make_tuple(im_major, in_major, im_minor, in_minor, 0)); + + // MFMA accumulation + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + if constexpr(m0.value < (MRepeat - LocalPrefetchStages)) + { + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize * + (APackedSize * KPack / xdlops_gemm.K1PerXdlops); + static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}( + [&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_m3_k, + make_tuple(Number<((m0 + LocalPrefetchStages) / MXdlPack) % + (MRepeat / MXdlPack)>{}, + I0, + Number{}, + I0, + Number{}), + a_block_bufs(I0), + a_thread_desc_, + make_tuple(I0, + I0, + Number{}, + k, + Number{}), + a_thread_buf); + }); + }); + } + }); + } + } + + // Length: A[ARegBuf, MWave, MXdlPack, KRepeat, KPack] + // Order: 1 0 3 2 4 + static constexpr auto ARegBuf = 2; + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{}, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4>, + 4, + A_K1, + A_K1>; + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = 
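+    // Per-thread scale descriptors: [repeat groups, K groups, packed scale words per xdlops
+    // call], mirroring the (m0/n0, k0) order of the scale prefetch loops above; the third
+    // extent holds the packed e8m0 scale words consumed by one xdlops_gemm.Run().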
make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{})); + + protected: + // using Base::a_thread_copy_; + // using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp new file mode 100755 index 000000000000..9835d9325b43 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp @@ -0,0 +1,574 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimimal pipeline with highest resource request +// GlobalPrefetchStages: 3 +// LocalPreFillStages: 2 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 2 + +template +struct BlockwiseGemmXdlops_pipeline_v4 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v4 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 3; + static constexpr index_t PrefillStages = 2; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopUnroll = 2; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % HotloopUnroll == 1) + { + return TailNumber::Odd; + } + else + { + return TailNumber::Even; + } + } + + __device__ static constexpr void HotLoopScheduler() + { + // TODO: Take data type into consideration as pipe ver 3 + // A-B splited schedule + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? 
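+        // When the LDS read width is below 16 bytes the compiler is expected to merge pairs
+        // of reads into ds_read2, so the instruction count used for scheduling is halved;
+        // 16-byte (b128) reads are counted as-is.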
HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_issue_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_dswrite_per_issue_a = + (HotLoopInstList::A_LDS_Write_Inst_Num + num_issue_a - 1) / num_issue_a; + constexpr auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a; + + constexpr auto num_issue_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + constexpr auto num_dswrite_per_issue_b = + (HotLoopInstList::B_LDS_Write_Inst_Num + num_issue_b - 1) / num_issue_b; + constexpr auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b; + + constexpr auto num_mfma_per_issue = + HotLoopInstList::C_MFMA_Inst_Num / (num_issue_a + num_issue_b); + + static_for<0, num_issue_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) { + ignore = idsread; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_dsread_per_issue_a - + num_dswrite_per_issue_a, + 0); // MFMA + }); + + static_for<0, num_issue_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) { + ignore = idsread; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_dsread_per_issue_a - + num_dswrite_per_issue_b, + 0); // MFMA + }); + __builtin_amdgcn_sched_barrier(0); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs; + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0)); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + 
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(I0)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(I0), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(I0)); + }); + }); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 2 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1)); + + // Global prefetch 3 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + // This hot loop has two legacy loopover, to implement the double local buffer strategy + do + { + auto LoopFunc = [&](auto lds_read_buf, + auto lds_read_reg_buf, + auto lds_write_buf, + auto mfma_reg_buf) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf)); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + }; + + LoopFunc(I1, I1, I0, I0); + LoopFunc(I0, I0, I1, I1); + + i += HotloopUnroll; + } while(i < (num_loop - PrefetchStages)); + } + + auto ReadWriteCompFunc = [&](auto lds_read_buf, + auto lds_read_reg_buf, + auto lds_write_buf, + auto mfma_reg_buf) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + 
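+                    // Tail helpers: ReadWriteCompFunc computes on one register buffer while it
+                    // drains the last outstanding global prefetch into the spare LDS buffer;
+                    // ReadCompFunc only refills registers from LDS and computes; CompFunc runs
+                    // the remaining MFMAs. The Odd/Even tails below chain them to finish the
+                    // tiles already prefetched (three for Odd, two for Even).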
make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf)); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + }; + + auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + }; + + auto CompFunc = [&](auto mfma_reg_buf) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }; + // tail + if constexpr(TailNum == TailNumber::Odd) + { + ReadWriteCompFunc(I1, I1, I0, I0); + ReadCompFunc(I0, 
I0, I1); + CompFunc(I0); + } + else if constexpr(TailNum == TailNumber::Even) + { + ReadCompFunc(I1, I1, I0); + CompFunc(I1); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp new file mode 100755 index 000000000000..f35c7a97cc32 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4_b_scale.hpp @@ -0,0 +1,686 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimimal pipeline with highest resource request +// GlobalPrefetchStages: 4 +// LocalPreFillStages: 2 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 2 + +template +struct BlockwiseGemmXdlops_pipeline_v4_b_scale +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v4_b_scale + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 3; + static constexpr index_t PrefillStages = 2; + static constexpr index_t GlobalBufferNum = 1; + static constexpr index_t HotloopUnroll = 2; + + __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % HotloopUnroll == 1) + { + return TailNumber::Odd; + } + else + { + return TailNumber::Even; + } + } + + __device__ static constexpr void HotLoopScheduler() + { + // TODO: Take data type into consideration as pipe ver 3 + // A-B splited schedule + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? 
HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_issue_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_dswrite_per_issue_a = + (HotLoopInstList::A_LDS_Write_Inst_Num + num_issue_a - 1) / num_issue_a; + constexpr auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a; + + constexpr auto num_issue_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + constexpr auto num_dswrite_per_issue_b = + (HotLoopInstList::B_LDS_Write_Inst_Num + num_issue_b - 1) / num_issue_b; + constexpr auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b; + + constexpr auto num_mfma_per_issue = + HotLoopInstList::C_MFMA_Inst_Num / (num_issue_a + num_issue_b); + + static_for<0, num_issue_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) { + ignore = idsread; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_dsread_per_issue_a - + num_dswrite_per_issue_a, + 0); // MFMA + }); + + static_for<0, num_issue_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) { + ignore = idsread; + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier(0x008, + num_mfma_per_issue - num_dsread_per_issue_a - + num_dswrite_per_issue_b, + 0); // MFMA + }); + __builtin_amdgcn_sched_barrier(0); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + // BScaleThreadCopy + const BScaleGridDesc& b_scale_grid_desc, + const BScaleThreadDesc& b_scale_thread_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + const BScaleThreadTransferStep& b_scale_thread_copy_step, + // num loop + index_t num_loop, + index_t num_loop_per_scale) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // B scale buffer + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + StaticallyIndexedArray{}> a_thread_bufs; + StaticallyIndexedArray{}> b_thread_bufs; + StaticallyIndexedArray{}> b_scale_thread_bufs; + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, 
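+        // B-scale stepping: after each per-n0 copy the window advances by step <0>; once a
+        // full prefetch pass is done it advances by step <2> when the next iteration crosses
+        // a scale-block boundary (every num_loop_per_scale K blocks) and by step <1>
+        // otherwise, so one set of B scales is reused for num_loop_per_scale K blocks. The
+        // actual window offsets come from b_scale_thread_copy_step, supplied by the caller.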
b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + + if(num_loop_per_scale == 1) + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + } + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0)); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_bufs(I1)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + + if(2 % num_loop_per_scale == 0) + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + } + + // Local prefetch 1 + block_sync_lds(); + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(I0), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(I0)); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(I0), + b_scale_thread_bufs(I0)[n0], + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(I0)); + }); + }); + }); + + // Local prefill 2 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1)); + + // Global prefetch 3 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_bufs(I0)); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + + if(3 % num_loop_per_scale == 0) + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + // This hot loop has two legacy loopover, to implement the double local buffer strategy + do + { + auto LoopFunc = [&](auto lds_read_buf, + auto lds_read_reg_buf, + 
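+                                    // lds_read_buf: LDS buffer read this iteration;
+                                    // lds_read_reg_buf: register buffer it is staged into;
+                                    // lds_write_buf: LDS buffer the next global tile is
+                                    // written to; mfma_reg_buf: register buffer feeding the
+                                    // MFMAs. Unlike the plain v4 pipeline, b_thread_copy_
+                                    // here also takes the per-n0 B scale, so the scale can be
+                                    // applied during the LDS-to-register copy of B.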
auto lds_write_buf, + auto mfma_reg_buf) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_scale_thread_bufs(lds_read_buf)[n0], + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + + // B scale copy + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_bufs(lds_read_reg_buf)); + + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<0>{})); + }); + + if((i + 4 + mfma_reg_buf.value) % num_loop_per_scale == 0) + { + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, b_scale_thread_copy_step.At(Number<1>{})); + } + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf)); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf] + [Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf] + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + }; + + LoopFunc(I1, I1, I0, I0); + LoopFunc(I0, I0, I1, I1); + + i += HotloopUnroll; + } while(i < (num_loop - PrefetchStages)); + } + + auto ReadWriteCompFunc = [&](auto lds_read_buf, + auto lds_read_reg_buf, + auto lds_write_buf, + auto mfma_reg_buf) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_scale_thread_bufs(lds_read_buf)[n0], + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf)); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf)); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 
1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + }; + + auto ReadCompFunc = [&](auto lds_read_buf, auto lds_read_reg_buf, auto mfma_reg_buf) { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf.At(lds_read_buf), + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_bufs(lds_read_reg_buf)); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf.At(lds_read_buf), + b_scale_thread_bufs(lds_read_buf)[n0], + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_bufs(lds_read_reg_buf)); + }); + }); + + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + + HotLoopScheduler(); + }; + + auto CompFunc = [&](auto mfma_reg_buf) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_bufs[mfma_reg_buf][Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[mfma_reg_buf][Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + }; + + // tail + if constexpr(TailNum == TailNumber::Odd) + { + ReadWriteCompFunc(I1, I1, I0, I0); + ReadCompFunc(I0, I0, I1); + CompFunc(I0); + } + else if constexpr(TailNum == TailNumber::Even) + { + ReadCompFunc(I1, I1, I0); + CompFunc(I1); + } + } + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp new file mode 100755 index 000000000000..99934fa74e29 --- /dev/null +++ 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp @@ -0,0 +1,664 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Compute optimized pipeline +// GlobalPrefetchStages: 3 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 1 +// LocalSharedMemoryBuffer: 2 + +template +struct BlockwiseGemmXdlops_pipeline_v5 +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v5 + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::A_K1; + using Base::B_K1; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::xdlops_gemm; + using typename Base::HotLoopInstList; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + static constexpr index_t PrefetchStages = 3; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 2; + static constexpr index_t HotloopUnroll = 2; + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + if(num_loop % HotloopUnroll == 1) + { + return TailNumber::Odd; + } + else + { + return TailNumber::Even; + } + } + + __device__ static constexpr auto HotLoopScheduler() + { + // TODO: Take data type into consideration as pipe ver 3 + // A/B split schedule + // compiler is likely to use ds_read2 when instruction width smaller than 16bytes + constexpr auto num_ds_read_inst_a = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 + ? HotLoopInstList::A_LDS_Read_Inst_Num + : HotLoopInstList::A_LDS_Read_Inst_Num / 2; + constexpr auto num_ds_read_inst_b = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 + ? HotLoopInstList::B_LDS_Read_Inst_Num + : HotLoopInstList::B_LDS_Read_Inst_Num / 2; + + constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num; + constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num; + + constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num; + constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num; + + constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num; + + constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle; + constexpr auto ds_read_a_issue_cycle = + HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4; + constexpr auto ds_read_b_issue_cycle = + HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 
8 : 4; + constexpr auto ds_read_a_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); + constexpr auto ds_read_b_mfma_rate = + (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle); + + constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1); + constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1); + constexpr auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat; + constexpr auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat; + + constexpr auto num_dsread_stage1_a_mfma = + (num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_stage1_b_mfma = + (num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + constexpr auto num_dsread_stage3_a_mfma = + (num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate; + constexpr auto num_dsread_stage3_b_mfma = + (num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate; + + constexpr auto num_mfma_stage2 = num_mfma_inst - num_ds_read_inst_a / ds_read_a_mfma_rate - + num_ds_read_inst_b / ds_read_b_mfma_rate; + constexpr auto num_mfma_per_issue = + num_mfma_stage2 / (num_buffer_load_inst_a + num_buffer_load_inst_b); + constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a; + constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b; + + // stage 1 + static_for<0, num_dsread_stage1_a_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, num_dsread_stage1_b_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + // stage 2 + static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA + }); + static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) { + ignore = i; + static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) { + ignore = idswrite; + __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_group_barrier( + 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA + }); + + // stage 3 + static_for<0, num_dsread_stage3_a_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage3_a - (i + 1) * 
ds_read_a_mfma_rate) >= + ds_read_a_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + static_for<0, num_dsread_stage3_b_mfma, 1>{}([&](auto i) { + ignore = i; + if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >= + ds_read_b_mfma_rate) + { + __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read + } + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, + num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate, + 0); // DS read + } + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + }); + + // IGLP COMPILER BUG: + // If comment out following scheduler barrier would cause sanity fail. + __builtin_amdgcn_sched_barrier(0); + } + + template + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Global prefetch 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Global prefetch 3 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // Local prefetch 1 + block_sync_lds(); + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, I0, I0), + b_thread_buf); + }); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto vmem_buf) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KRepeat, 1>{}([&](auto k0) { + if constexpr(k0 == (KRepeat - 1)) + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf); + b_blockwise_copy.RunWrite(b_block_desc, 
b_block_buf, vmem_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, vmem_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, vmem_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + } + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + }); + static_for<0, KPack, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, I0), + a_thread_buf); + }); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, I0, I0), + b_thread_buf); + }); + }); + + HotLoopScheduler(); + }; + + LoopFunc(I0); + LoopFunc(I1); + + i += HotloopUnroll; + } while(i < (num_loop - PrefetchStages)); + } + // tail + auto ReadWriteCompFunc = [&](auto vmem_buf) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KRepeat, 1>{}([&](auto k0) { + if constexpr(k0 == (KRepeat - 1)) + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf); + + block_sync_lds(); + } + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + }); + static_for<0, KPack, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, I0), + a_thread_buf); + }); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, I0, I0), + b_thread_buf); + }); + }); + + HotLoopScheduler(); + }; + auto ReadCompFunc = [&]() { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KRepeat - 1, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + }); + static_for<0, KPack, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t 
c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, I0), + a_thread_buf); + }); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, I0, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = a_thread_buf + [Number{}]; + }); + static_for<0, KPack, 1>{}([&](auto ik) { + b_thread_vec.template AsType()(ik) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + + HotLoopScheduler(); + }; + + if constexpr(TailNum == TailNumber::Odd) + { + ReadWriteCompFunc(I0); + ReadWriteCompFunc(I1); + ReadCompFunc(); + } + else if constexpr(TailNum == TailNumber::Even) + { + ReadWriteCompFunc(I0); + ReadCompFunc(); + } + } + + protected: + // A[MRepeat, I1, I1, KPack] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, Number{})); + + // B[NRepeat, N1, N2, KPack] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()}; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp new file mode 100755 index 000000000000..e9f9b0be7eab --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_smfmac_xdlops.hpp @@ -0,0 +1,453 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
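// ---------------------------------------------------------------------------
// Illustrative host-side sketch (reviewer addition, not part of either header):
// it restates the loop bookkeeping used by BlockwiseGemmXdlops_pipeline_v5
// above, assuming only the constants the struct itself declares
// (PrefetchStages = 3, HotloopUnroll = 2). The hot loop advances by
// HotloopUnroll per trip and stops once i reaches num_loop - PrefetchStages;
// the remaining iterations are drained by the Odd/Even tail paths.
// ---------------------------------------------------------------------------
#include <cstdio>

namespace pipeline_v5_sketch {

enum class TailNumber { Odd, Even };

constexpr int PrefetchStages = 3; // global prefetches issued before the hot loop
constexpr int HotloopUnroll  = 2; // K-chunks consumed per hot-loop trip

constexpr bool BlockHasHotloop(int num_loop) { return num_loop > PrefetchStages; }

constexpr TailNumber BlockLoopTailNum(int num_loop)
{
    return (num_loop % HotloopUnroll == 1) ? TailNumber::Odd : TailNumber::Even;
}

inline void Describe(int num_loop)
{
    int hot_trips = 0;
    if(BlockHasHotloop(num_loop))
    {
        // Mirrors: do { LoopFunc(I0); LoopFunc(I1); i += HotloopUnroll; }
        //          while(i < (num_loop - PrefetchStages));
        for(int i = 0; i < num_loop - PrefetchStages; i += HotloopUnroll)
            ++hot_trips;
    }
    std::printf("num_loop=%d  has_hot_loop=%d  hot_trips=%d  tail=%s\n",
                num_loop,
                BlockHasHotloop(num_loop) ? 1 : 0,
                hot_trips,
                BlockLoopTailNum(num_loop) == TailNumber::Odd ? "Odd" : "Even");
}

} // namespace pipeline_v5_sketch

int main()
{
    for(int num_loop : {2, 3, 4, 5, 9})
        pipeline_v5_sketch::Describe(num_loop);
    return 0;
}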
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template +__host__ __device__ static constexpr auto +MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&) +{ + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); +} + +template +struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = + BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = + SparseXdlopsGemm{}; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto 
mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i); + + return make_tuple(Number{}, + Number{}, + waveId_m, + waveId_n, + blk_idx[I0], + blk_idx[I1], + blk_idx[I2], + blk_idx[I3]); + } + + __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0, + "MPerBlock must be divisible by MPerXDL * MRepeat"); + static_assert(NPerBlock % (NPerXDL * NRepeat) == 0, + "NPerBlock must be divisible by NPerXDL * NRepeat"); + + static_assert( + KPack % (16 * sizeof(ComputeTypeA)) == 0, + "KPack must be divisbile by number of elements processed in single smfmac instruction"); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + 
Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K(); + static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K(); + + // Prepares data in a_thread_buf by squeezing values by ommiting zeros to adjust it to 2:4 + // structural sparsity. 
The indexes of non-zero elements are stored in idx_buf and used later in + // smfmac instruction + template + __device__ void SetIdxSqueezeA(AThreadBuf& a_thread_buf, IdxBuf& idx_buf) + { + static constexpr int32_t bit_clear_masks[4] = {0b11, 0b1100, 0b110000, 0b11000000}; + static constexpr int32_t processed_elems = 16 / sizeof(ComputeTypeA); + + static_for<0, num_elems, processed_elems>{}([&](auto i) { + constexpr int idx_reg_num = i / (16 * sizeof(ComputeTypeA)); + constexpr int idx_reg_part = (i % 32) / processed_elems; + + vector_type a_thread_vec; + static_for<0, processed_elems, 1>{}([&](auto j) { + a_thread_vec.template AsType()(j) = a_thread_buf + [Number{}]; + }); + + uint8_t idx = 0b11101110; // set to last 2 elems for both 4-elems subgroups by default + for(int j = 0; j < processed_elems; j += 4) + { + int32_t a_pos = idx_reg_part * processed_elems + j; + int32_t nonzero_pos = 0; + ComputeTypeA nonzero_elems[2] = {a_thread_vec[j + 2], a_thread_vec[j + 3]}; + for(int k = 0; k < 3; k += 1) + { + if(a_thread_vec[j + k] != 0.0f) + { + nonzero_elems[nonzero_pos] = a_thread_vec[j + k]; + idx &= ~bit_clear_masks[j / 2 + nonzero_pos]; + idx |= k << 2 * (j / 2 + nonzero_pos); + ++nonzero_pos; + } + } + a_thread_vec[j / 2] = nonzero_elems[0]; + a_thread_vec[j / 2 + 1] = nonzero_elems[1]; + } + IdxBuf[idx_reg_num].AsType()[Number{}] = idx; + + static_for<0, processed_elems / 2, 1>{}([&](auto j) { + a_thread_buf[Number{}] = a_thread_vec[j]; + }); + }); + } + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + static constexpr int32_t elems_per_idx = 16 * sizeof(ComputeTypeA); + auto idx_buf = make_static_buffer( + (a_thread_desc_.GetElementSpaceSize() + elems_per_idx - 1) / elems_per_idx); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + SetIdxSqueezeA(a_thread_buf, idx_buf, a_thread_desc_.GetElementSpaceSize()); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + static_for<0, KPerThread, KPack>{}([&](auto k) { + // a_thread_vec is smaller because it's structurally sparse 2:4 + vector_type a_thread_vec; + vector_type b_thread_vec; + vector_type idx_vec; + + static_for<0, KPack / 2, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); + + static_for<0, KPack, 1>{}([&](auto i) { + b_thread_vec.template AsType()(2 * i) = b_thread_buf + [Number{}]; + }); + + static_for<0, KPack / elems_per_idx, 1>{}([&](auto i) { + idx_vec.template AsType()(i) = idx_buf[k / elems_per_idx + i]; + }); + + // A is smaller because it's structurally sparse 2:4 + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + using mfma_input_type_idx = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + idx_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + 
}); + } + + protected: + // A[M0, M1, M2, KPerThread] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPerThread] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp new file mode 100755 index 000000000000..fa389c3402d1 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp @@ -0,0 +1,1030 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#define CK_MNK_LOOP + +namespace ck { + +#ifdef __gfx12__ +template +/* Option: Read from LDS, big buffer hold all threads required data + * Source + * A: K0PerBlock x MPerBlock x K1 + * B: K0PerBlock x NPerBlock x K1 + * Destination + * C, non-transpose + * thread level: MRepeat x NRepeat x MAccVgprs + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + * KPACK == WMMA_K = 16 + * + * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS) + * Source: + * A(if skip LDS): MRepeat x KPack + * B(if skip LDS): NRepeat x KPack + * Destination + * C, non-transpose + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + */ +struct BlockwiseGemmWMMA +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto WmmaK = Number<16>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. 
+ static constexpr index_t WaveSize = 32; + + // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer + // When not use LDS, each Row read half of whole data from source buffer, exchange the data via + // permutation + static constexpr index_t A_KRow = 2; + static constexpr index_t B_KRow = 2; + + static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5); + static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5); + + static constexpr auto wmma_gemm = + WmmaGemm{}; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + // Default, Block buffer in LDS, thread level offset enabled + __device__ static auto CalculateAThreadOriginDataIndex() + { + if constexpr(AEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); + + // |KRepeat |MRepeat|MWave |KRow |MLane |KPack + return make_tuple(0, 0, waveId_m, wmma_gemm.GetSubGroupId(), WMMA_a_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + if constexpr(BEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); + + // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack + return make_tuple(0, 0, waveId_n, wmma_gemm.GetSubGroupId(), WMMA_b_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); + + constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex7D(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D(); + + return make_tuple( + Number{}, waveId_m, blk_idx[I0], Number{}, waveId_n, blk_idx[I1], blk_idx[I2]); + } 
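// ---------------------------------------------------------------------------
// Illustrative host-side sketch (reviewer addition, not part of the header):
// the merge/unmerge tensor adaptors in GetWaveIdx() and
// CalculateCThreadOriginDataIndex() above boil down to the mixed-radix
// arithmetic below. The tile sizes used here (MPerBlock = NPerBlock = 128,
// MPerWMMA = NPerWMMA = 16, MRepeat = NRepeat = 4) are example values only,
// chosen so that MWaves = NWaves = 2 with the hard-coded WaveSize of 32.
// ---------------------------------------------------------------------------
#include <cstdio>

namespace wmma_index_sketch {

constexpr int WaveSize  = 32;
constexpr int MPerBlock = 128, NPerBlock = 128;
constexpr int MPerWMMA  = 16,  NPerWMMA  = 16;
constexpr int MRepeat   = 4,   NRepeat   = 4;
constexpr int MWaves    = MPerBlock / (MRepeat * MPerWMMA); // 2
constexpr int NWaves    = NPerBlock / (NRepeat * NPerWMMA); // 2

struct WaveIdx { int wave_m; int wave_n; int lane; };

// Inverse of the merge transform over (MWaves, NWaves, WaveSize):
// thread_id = (wave_m * NWaves + wave_n) * WaveSize + lane.
inline WaveIdx GetWaveIdx(int thread_id)
{
    const int lane = thread_id % WaveSize;
    const int wave = thread_id / WaveSize;
    return {wave / NWaves, wave % NWaves, lane};
}

// Unmerge over (MRepeat, MWaves, MPerWMMA), as used for the C-tile origin:
// m = (m0 * MWaves + wave_m) * MPerWMMA + blk_m.
inline int CThreadOriginM(int m0, int wave_m, int blk_m)
{
    return (m0 * MWaves + wave_m) * MPerWMMA + blk_m;
}

} // namespace wmma_index_sketch

int main()
{
    using namespace wmma_index_sketch;
    const int thread_id = 70; // any thread of the MWaves x NWaves x WaveSize block
    const WaveIdx w     = GetWaveIdx(thread_id);
    std::printf("thread %d -> wave_m=%d wave_n=%d lane=%d\n",
                thread_id, w.wave_m, w.wave_n, w.lane);
    // C row owned by repeat index m0 = 1 and lane-local row offset blk_m = 5:
    std::printf("c_thread_m = %d\n", CThreadOriginM(1, w.wave_m, 5));
    return 0;
}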
+ + using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); + __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), + Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && + NPerBlock % (NPerWMMA * NRepeat) == 0, + "wrong!"); + } + + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + + return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, NAccVgprs)); + } + + // Thread level, register decriptor. Vector-write + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; + return make_naive_tensor_descriptor( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, MAccVgprs), + make_tuple(Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + AccStride)); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple( + make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); + } + + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + 
.MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + // Provide dimension size + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + // Describe how data allocated in thread copy src buffer + // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma + static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1; + static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1; + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_assert(KPack % (A_K1 * A_KRow) == 0, ""); + static_assert(KPack % (B_K1 * B_KRow) == 0, ""); + + // basic intrinsic to determine loopover direction + if constexpr(MRepeat < NRepeat) + { + static_for<0, KPerBlock / KPack, 1>{}( + [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); + + static_for<0, KPack / B_KRow, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of + // k=0,kpack*1, .. 
+ // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); + + static_for<0, KPack / B_KRow, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, I1, I1, I1, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, I1, I1, I1, Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + // C[M, N, NumRegWMMA] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); + + template + struct AThreadCopySelector; + + template <> + struct AThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + }; + + template <> + struct AThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow< + FloatA, + FloatA, + decltype(a_block_desc_k0_m0_m1_m2_k1), + decltype(a_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + false>; + }; + + template + struct BThreadCopySelector; + + template <> + struct BThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + B_K1>; + }; + + template <> + struct BThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow< + FloatB, + FloatB, + decltype(b_block_desc_k0_n0_n1_n2_k1), + decltype(b_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + false>; + }; + + typename AThreadCopySelector::type a_thread_copy_; + typename BThreadCopySelector::type b_thread_copy_; +}; +#else +template +/* Option: Read from LDS, big buffer hold all threads required data + * Source + * A: K0PerBlock x MPerBlock x K1 + * B: K0PerBlock x NPerBlock x K1 + * Destination + * C, non-transpose + * thread level: MRepeat x NRepeat x MAccVgprs + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + * KPACK == WMMA_K = 16 + * + * Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS) + * Source: + * A(if skip LDS): MRepeat x KPack + * B(if skip LDS): NRepeat x KPack + * Destination + * C, non-transpose + * block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs + */ +struct 
BlockwiseGemmWMMA +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto WmmaK = Number<16>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one. + static constexpr index_t WaveSize = 32; + + // When use LDS, each Row(16 consecutive lanes) read whole data from source buffer + // When not use LDS, each Row read half of whole data from source buffer, exchange the data via + // permutation + static constexpr index_t A_KRow = AEnableLds ? 1 : 2; + static constexpr index_t B_KRow = BEnableLds ? 1 : 2; + static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5); + static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5); + + static constexpr auto wmma_gemm = + WmmaGemm{}; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + // Default, Block buffer in LDS, thread level offset enabled + __device__ static auto CalculateAThreadOriginDataIndex() + { + if constexpr(AEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); + + // |KRepeat |MRepeat|MWave |KRow |MLane |KPack + return make_tuple(0, 0, waveId_m, 0, WMMA_a_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + if constexpr(BEnableLds) + { + const auto wave_idx = GetWaveIdx(); + const auto waveId_n = wave_idx[I1]; + const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); + + // |KRepeat |NRepeat|Nwave |KRow |NLane |KPack + return make_tuple(0, 0, waveId_n, 0, WMMA_b_idx, 0); + } + else + { + return make_tuple(0, 0, 0, 0, 0, 0); + } + } + + template + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(); + + constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, 
blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto CalculateCThreadOriginDataIndex7D(Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk3D(); + + return make_tuple( + Number{}, waveId_m, blk_idx[I0], Number{}, waveId_n, blk_idx[I1], blk_idx[I2]); + } + + using Tuple6 = decltype(CalculateAThreadOriginDataIndex()); + __host__ __device__ BlockwiseGemmWMMA(Tuple6 a_origin = CalculateAThreadOriginDataIndex(), + Tuple6 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 && + NPerBlock % (NPerWMMA * NRepeat) == 0, + "wrong!"); + } + + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + + return make_naive_tensor_descriptor_packed( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, NAccVgprs)); + } + + // Thread level, register decriptor. 
Vector-write + __host__ __device__ static constexpr auto + GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens = + wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths(); + + constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2]; + constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3]; + return make_naive_tensor_descriptor( + // |MRepeat |MWave |MSubGroup |NRepeat |NWave + // |NThreadPerSubGroup |MAccVgprs + make_tuple(Number{}, I1, I1, Number{}, I1, I1, MAccVgprs), + make_tuple(Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + Number{} * MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + MAccVgprs * AccStride, + AccStride)); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple( + make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma); + } + + // transposed WMMA output C' = B' * A' + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + // Provide dimension size + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs() + { + constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return wmma_gemm + .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma); + } + + // Describe how data allocated in thread copy src buffer + // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma + static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1; + static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1; + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + // basic intrinsic to determine loopover direction + if constexpr(MRepeat < NRepeat) + { + static_for<0, 
KPerBlock / KPack, 1>{}( + [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KPerBlock / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of + // k=0,kpack*1, .. + // read B + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, I0, I0, I0, I0), + b_thread_buf); + // read A + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, I0, I0, I0, I0), + a_thread_buf); + + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + }); + + using wmma_input_type_a = typename vector_type::type; + using wmma_input_type_b = typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + } + + protected: + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{}), + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number<1>{})); + + // C[M, N, NumRegWMMA] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, wmma_gemm.GetRegSizePerWmma())); + + template + struct AThreadCopySelector; + + template <> + struct AThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + A_K1>; + }; + + template <> + struct AThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow< + FloatA, + FloatA, + decltype(a_block_desc_k0_m0_m1_m2_k1), + decltype(a_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + A_K1, + 0x76543210, + 0xfedcba98, + TransposeC ? 
false : true>; + }; + + template + struct BThreadCopySelector; + + template <> + struct BThreadCopySelector + { + using type = + ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + B_K1>; + }; + + template <> + struct BThreadCopySelector + { + using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow< + FloatB, + FloatB, + decltype(b_block_desc_k0_n0_n1_n2_k1), + decltype(b_thread_desc_), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5>, + 5, + B_K1, + 0x76543210, + 0xfedcba98, + TransposeC ? true : false>; + }; + + typename AThreadCopySelector::type a_thread_copy_; + typename BThreadCopySelector::type b_thread_copy_; +}; +#endif + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp new file mode 100755 index 000000000000..d3f6344c27eb --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -0,0 +1,1016 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template +__host__ __device__ static constexpr auto +MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&) +{ + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); +} + +template +struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); + static constexpr index_t KPerBlock = + BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = + XdlopsGemm{}; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = 
make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i); + + return make_tuple(Number{}, + Number{}, + waveId_m, + waveId_n, + blk_idx[I0], + blk_idx[I1], + blk_idx[I2], + blk_idx[I3]); + } + + __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), + "wrong! 
Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return 
xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K(); + static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K(); + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + + static_for<0, KPerThread, KPack>{}([&](auto k) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + protected: + // A[M0, M1, M2, KPerThread] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPerThread] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; +}; + +// Note: To facilitate the inter-wave loop scheduler, we need to explicitly set the macro +// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 as a few intrinsics are not yet available in +// the latest ROCm release. 
For unsupported compilers, inter-wave loop scheduler falls back to the +// default loop scheduler which is given by the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0 +template +struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 + : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 +{ + using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + +#if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING + using Base::a_block_desc_m0_m1_m2_k; + using Base::A_K1; + using Base::b_block_desc_n0_n1_n2_k; + using Base::B_K1; + using Base::c_thread_buf_; + using Base::c_thread_desc_; + using Base::CalculateAThreadOriginDataIndex; + using Base::CalculateBThreadOriginDataIndex; + using Base::I0; + using Base::I1; + using Base::KPerThread; + using Base::xdlops_gemm; + + static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack); + + // 2-wave optimized blockwise gemm + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, k), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, k), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, I0, I0), + b_thread_buf); + }); + __builtin_amdgcn_sched_barrier(0); + // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster, but except + // the first, as we can shorten non-MAC cluster a bit and there's no observable negative + // impact. The desired effect is waves in a workgroup executing MAC in sync. This avoids + // some out-of-sync waves hijacking MAC resource from other workgroups and reducing the + // chance of latency hiding by waiting for the rest of the workgroup at the eventual + // sync point. 
+ if constexpr(k.value != 0 || KPerInnerLoop == KPerThread) + { +#ifdef __gfx12__ + asm volatile("\ + s_barrier_signal -1 \n \ + s_barrier_wait -1 \ + " ::); +#else + asm volatile("s_barrier" ::); +#endif + __builtin_amdgcn_sched_barrier(0); + } + static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(i) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // The block_sync_lds() here performs double duty: + // A) safeguard against data hazard because barrier from blockwise_gemm is + // moved here B) reduce VMEM FIFO congestion by applying small delays to + // different wavefronts It is performed near the end of MAC cluster to + // minimize lgkmcnt penalty + if constexpr(k.value == KPerThread - KPerInnerLoop && + k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 && + n0.value == NRepeat - 1) + { + __builtin_amdgcn_sched_barrier(0); + block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); + } + + // TODO: insert setprio in more precise manner since we + // could have more than >1 MFMA instructions in single call + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(1); + __builtin_amdgcn_sched_barrier(0); + } + }); + }); + }); + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_setprio(0); + __builtin_amdgcn_sched_barrier(0); + }); + } + + protected: + // A[M0, M1, M2, KPerInnerLoop] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{})); + + // B[N0, N1, N2, KPerInnerLoop] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, I1, Number{})); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; + +#endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING +}; + +template +constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() +{ + if constexpr(LoopSched == LoopScheduler::Default) + { + return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + } + else if constexpr(LoopSched == LoopScheduler::Interwave) + { + return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + } +}; + +/** + * @brief Blockwise gemm + * + * Supports + * 1. regular XDL output M2_M3_M4_M2 and transposed XDL output M2_N2_N3_N4 + * 2. decoupled input tile descriptor and mma tile descriptor in order to support both vgpr and LDS + * source buffer + * 3. 
configurable k index starting position and step size after each FMA/XDL instruction + */ + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename ATileDesc, + typename BTileDesc, + typename AMmaTileDesc, + typename BMmaTileDesc, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t MPerXDL, + index_t NPerXDL, + index_t MRepeat, + index_t NRepeat, + index_t KPack, + bool TransposeC = false, + index_t AMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops, + index_t BMmaKStride = + KPack* XdlopsGemm{}.K0PerXdlops> +struct BlockwiseGemmXdlops_v2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); + static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = + XdlopsGemm{}; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + static_assert(KPerThread % KPack == 0, + "Wrong KPack setting; try increasing KPerThread or decreasing KPack"); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, 
blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex8D(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i); + + return make_tuple( + m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]); + } + + using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); + + __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), + Tuple4 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, N, M0, M1, M2)); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + 
make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; + static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_block_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + static_for<0, KPerThread / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ... 
+ static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0), + b_thread_buf); + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + protected: + // A[M0, M1, M2, KPack] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPack] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp new file mode 100755 index 000000000000..287c6701c339 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp @@ -0,0 +1,320 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
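Note on the Run() loop nest of BlockwiseGemmXdlops_v2 above: per thread it is a small register-tiling pattern in which, for each KPack-wide chunk of the K slice, every (MRepeat, NRepeat) pair feeds one A fragment and one B fragment into the MFMA and accumulates into the per-thread C buffer. Below is a minimal host-side sketch of that accumulation order; the tile sizes are placeholders, and the plain dot product stands in for xdlops_gemm.Run(), which in the real kernel also reduces across the lanes of a wave.

```cpp
#include <array>
#include <cstdio>

// Illustrative compile-time tile parameters (placeholders, not the kernel's).
constexpr int MRepeat = 2, NRepeat = 2, KPerThread = 8, KPack = 4;

// Per-thread operand fragments: one KPack-wide slice per (repeat, k-chunk).
using AFrag = std::array<float, KPack>;
using BFrag = std::array<float, KPack>;

int main()
{
    std::array<std::array<AFrag, KPerThread / KPack>, MRepeat> a{}; // a[m0][k]
    std::array<std::array<BFrag, KPerThread / KPack>, NRepeat> b{}; // b[n0][k]
    float c[MRepeat][NRepeat] = {};                                 // c_thread_buf stand-in

    // Same loop order as Run(): k chunks outermost, then MRepeat, then NRepeat,
    // with a KPack-wide "MFMA" (here just a dot product) innermost.
    for(int k = 0; k < KPerThread / KPack; ++k)
        for(int m0 = 0; m0 < MRepeat; ++m0)
            for(int n0 = 0; n0 < NRepeat; ++n0)
                for(int i = 0; i < KPack; ++i)
                    c[m0][n0] += a[m0][k][i] * b[n0][k][i]; // stands in for xdlops_gemm.Run

    std::printf("c[0][0] = %f\n", c[0][0]);
    return 0;
}
```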
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template +struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr index_t WaveSize = 64; + + static constexpr index_t KPerBlock = K0PerBlock * KPack; + + static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0); + static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = XdlopsGemm{}; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = get_thread_local_1d_id(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1() + { + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0K0BN0N1N2N3K1BlockDesc::IsKnownAtCompileTime(), + "wrong! 
Desc should be known at compile-time"); + + static_assert(BlockSize == MWaves * NWaves * WaveSize, + "BlockSize != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + 
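The make_unmerge_transform calls in the C descriptors above split a flat M (or N) coordinate into (repeat, wave, intra-XDL) digits with strides (MWaves*MPerXDL, MPerXDL, 1). A rough standalone illustration of that index arithmetic, using made-up sizes; the actual descriptors additionally carry the per-XDL M0/M1/M2/N sub-blocking supplied by xdlops_gemm.

```cpp
#include <cstdio>

// Hypothetical tile sizes for illustration only.
constexpr int MPerXDL = 32; // rows covered by one MFMA instruction
constexpr int MWaves  = 2;  // waves along M within the block
// M0 = MPerBlock / (MWaves * MPerXDL) would be the outer "repeat" count.

// Mirror of make_unmerge_transform(make_tuple(M0, MWaves, MPerXDL)):
// decompose m into digits with strides (MWaves*MPerXDL, MPerXDL, 1).
struct MIdx { int m0, wave_m, xdl_m; };

MIdx unmerge_m(int m)
{
    return {m / (MWaves * MPerXDL), (m / MPerXDL) % MWaves, m % MPerXDL};
}

int main()
{
    // 100 = 1*64 + 1*32 + 4  ->  (m0 = 1, wave_m = 1, xdl_m = 4)
    MIdx idx = unmerge_m(100);
    std::printf("m0=%d wave_m=%d xdl_m=%d\n", idx.m0, idx.wave_m, idx.xdl_m);
    return 0;
}
```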
__host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __device__ void MoveABlockSliceWindow() + { + a_thread_copy_.MoveSrcSliceWindow(a_block_desc_m0_m1_m2_k, + make_multi_index(0, 0, 0, K0PerBlock * KPack)); + } + __device__ void ResetABlockStartWindow() + { + a_thread_copy_.SetSrcCoord(CalculateAThreadOriginDataIndex()); + } + + static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K(); + + template + __device__ void Run(const ABlockBuffer& a_block_buf, + const BBlockBuffer& b_thread_buf, + CThreadBuffer& c_thread_buf) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + // read A + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, I0, I0, I0), + a_thread_buf); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + static_for<0, KPerThread, KPack>{}([&](auto k) { + vector_type a_thread_vec; + vector_type b_thread_vec; + constexpr index_t k0 = k / KPack; + static_for<0, KPack, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); + }); + }); + }); + } + + private: + // A[M0, M1, M2, KPerThread] + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number{})); + + // B[N0, N1, N2, KPerThread] + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, // KPerThread + Number{}, // repeat + Number{})); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp new file mode 100755 index 000000000000..2fb7242708da --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_softmax.hpp @@ -0,0 +1,125 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
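Both blockwise GEMM variants above recover their wave coordinates from the flat thread id through the same merge adaptor over (MWaves, NWaves, WaveSize). A quick sketch of that decomposition with assumed sizes; the real values come from the template parameters, and the block size must equal MWaves * NWaves * WaveSize.

```cpp
#include <cstdio>

// Assumed wave layout for illustration; the real values are template parameters.
constexpr int MWaves = 2, NWaves = 2, WaveSize = 64;

// Inverse of make_merge_transform(make_tuple(MWaves, NWaves, WaveSize)):
// peel digits off the flat thread id, fastest-varying dimension first.
struct WaveIdx { int wave_m, wave_n, lane; };

WaveIdx get_wave_idx(int thread_id)
{
    const int lane   = thread_id % WaveSize;
    const int wave_n = (thread_id / WaveSize) % NWaves;
    const int wave_m = thread_id / (WaveSize * NWaves);
    return {wave_m, wave_n, lane};
}

int main()
{
    // Thread 200 in a 256-thread block: 200 = 1*128 + 1*64 + 8.
    WaveIdx w = get_wave_idx(200);
    std::printf("wave_m=%d wave_n=%d lane=%d\n", w.wave_m, w.wave_n, w.lane);
    return 0;
}
```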
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" + +namespace ck { + +/** + * @brief Blockwise softmax + * + * @tparam BlockSize Block size + * @tparam AccDataType Accumulator data type + * @tparam ThreadMap_M_K Thread id to m_k + * @tparam ThreadClusterDesc_M_K Threadwise cluster descriptor + * @tparam ThreadSliceDesc_M_K Threadwise slices descriptor + * @tparam IgnoreNaN Flag to ignore NaN, false by default + */ +template +struct BlockwiseSoftmax +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr index_t MRepeat = ThreadSliceDesc_M_K{}.GetLength(I0); + static constexpr index_t KRepeat = ThreadSliceDesc_M_K{}.GetLength(I1); + + using ThreadSliceDesc_M = decltype(make_naive_tensor_descriptor_packed( + make_tuple(ThreadSliceDesc_M_K{}.GetLength(I0)))); + + using ThreadwiseMaxReduce = typename conditional< + IgnoreNaN, + ThreadwiseReduction>, + ThreadwiseReduction>::type; + + using ThreadwiseSumReduce = typename conditional< + IgnoreNaN, + ThreadwiseReduction>, + ThreadwiseReduction>::type; + + using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths()); + + using BlockwiseMaxReduce = PartitionedBlockwiseReduction_v2; + + using BlockwiseSumReduce = PartitionedBlockwiseReduction_v2; + + using BufferType = StaticBuffer; + + template + __host__ __device__ void Run(CThreadBuffer& in_thread_buf, WorkspaceBuffer& reduce_work_buf) + { + // find max value + static_for<0, MRepeat, 1>{}([&](auto I) { + max_value_buf(I) = reduce::Max::template GetIdentityValue(); + }); + ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf); + static_for<0, MRepeat, 1>{}([&](auto I) { + BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I)); + block_sync_lds(); + }); + + // calculate exp for elements, P=exp(s-max) + static_for<0, MRepeat, 1>{}([&](auto iM) { + static_for<0, KRepeat, 1>{}([&](auto iK) { + auto offset = Number{}; + in_thread_buf(offset) = IgnoreNaN && ck::math::isnan(in_thread_buf[offset]) + ? 0 + : math::exp(in_thread_buf[offset] - max_value_buf(iM)); + }); + }); + + // sum data + static_for<0, MRepeat, 1>{}([&](auto I) { + sum_value_buf(I) = reduce::Add::template GetIdentityValue(); + }); + ThreadwiseSumReduce::Reduce(in_thread_buf, sum_value_buf); + static_for<0, MRepeat, 1>{}([&](auto I) { + BlockwiseSumReduce::Reduce(reduce_work_buf, sum_value_buf(I)); + block_sync_lds(); + }); + } + + BufferType max_value_buf; + BufferType sum_value_buf; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp new file mode 100755 index 000000000000..d8da134a3415 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp @@ -0,0 +1,156 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
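For reference, the BlockwiseSoftmax::Run() just added follows the usual numerically stable recipe: reduce the row maximum, exponentiate in place as P = exp(s - max) (optionally zeroing NaN inputs when IgnoreNaN is set), then reduce the row sum. A single-threaded sketch of that sequence follows; as in the struct, the division by the sum is left to the caller.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Single-threaded reference of the BlockwiseSoftmax::Run() recipe for one row:
// 1) row max, 2) in-place P = exp(s - max), 3) row sum of P.
// P stays unnormalized here, matching the struct; normalizing by the sum
// (if actual probabilities are wanted) happens later in the caller.
void softmax_row(std::vector<float>& s, float& max_value, float& sum_value)
{
    max_value = *std::max_element(s.begin(), s.end()); // threadwise + blockwise max reduce
    sum_value = 0.f;
    for(float& x : s)
    {
        x = std::exp(x - max_value); // P = exp(s - max)
        sum_value += x;              // threadwise + blockwise sum reduce
    }
}

int main()
{
    std::vector<float> row{1.f, 2.f, 3.f};
    float row_max = 0.f, row_sum = 0.f;
    softmax_row(row, row_max, row_sum);
    std::printf("max=%f sum=%f p0=%f\n", row_max, row_sum, row[0] / row_sum);
    return 0;
}
```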
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct BlockwiseTensorSliceTransfer_v5r1 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + using Index = MultiIndex; + + __device__ constexpr BlockwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin) + : threadwise_transfer_( + src_desc, make_zero_multi_index(), dst_desc, make_zero_multi_index()) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(), + "wrong! BlockSize too small"); + + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{}; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_desc, dst_buf); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + // SrcMoveSliceWindowStepHack to control index calculation move slice window + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& step, + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow( + src_desc, step, 
src_move_slice_window_step_hack); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(BlockSize == thread_cluster_desc_.GetElementSize() or + get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v5r1; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp new file mode 100755 index 000000000000..820a08fc482b --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/blockwise_welford.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/get_shift.hpp" + +namespace ck { + +// clang-format off +// Assume: +// 1) work_buffer is buffer (typically LDS) allocated outside as workspace +// 2) work_buffer has T elements, and space size is no less than 3*BlockSize +// 3) mean_value, var_value and count is the input data in vgpr from each thread +// 4) mean_value, var_value and count is the over-written reduced output in vgpr for each thread +// 5) Merge mean and M from ThreadwiseWelford +// clang-format on +template +struct BlockwiseWelford +{ + static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1), + "The product of cluster lengths should be same as BlockSize!"); + + static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0); + static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1); + + static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + template + __device__ static inline void + Merge(T& mean_a, T& var_a, CountDataType& count_a, T mean_b, T var_b, CountDataType count_b) + { + CountDataType count = count_a + count_b; + T count_b_over_count = count == 0 ? 
type_convert(0) : type_convert(count_b) / count; + T delta = mean_b - mean_a; + mean_a += delta * count_b_over_count; + var_a += var_b + delta * delta * count_a * count_b_over_count; + count_a = count; + } + + template + __device__ static void Run(T& mean_value, T& var_value, CountDataType& count) + { + __shared__ T mean_block_buf[BlockSize]; + __shared__ T var_block_buf[BlockSize]; + __shared__ CountDataType count_block_buf[BlockSize]; + + constexpr auto cluster_len_shift = get_shift(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); + + const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}]; + const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}]; + + index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx); + + mean_block_buf[offset1] = mean_value; + var_block_buf[offset1] = var_value; + count_block_buf[offset1] = count; + + block_sync_lds(); + + static_for<0, cluster_len_shift, 1>{}([&](auto I) { + constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I()); + + if(thread_k_cluster_id < indOffset) + { + index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + + make_tuple(0, indOffset)); + + T mean1 = mean_block_buf[offset1]; + T var1 = var_block_buf[offset1]; + CountDataType count1 = count_block_buf[offset1]; + + T mean2 = mean_block_buf[offset2]; + T var2 = var_block_buf[offset2]; + CountDataType count2 = count_block_buf[offset2]; + + Merge(mean1, var1, count1, mean2, var2, count2); + + mean_block_buf[offset1] = mean1; + var_block_buf[offset1] = var1; + count_block_buf[offset1] = count1; + } + + block_sync_lds(); + }); + + index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0)); + + count = count_block_buf[offset]; + mean_value = mean_block_buf[offset]; + + if constexpr(GetActualVariance) + var_value = var_block_buf[offset] / count; + else + var_value = var_block_buf[offset]; + }; +}; +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp new file mode 100755 index 000000000000..82667e235238 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp @@ -0,0 +1,244 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
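The Merge() shown above in BlockwiseWelford is the standard pairwise Welford/Chan combination of two partial (mean, M2, count) triples; note that var_a in the struct carries the un-normalized sum of squared deviations, which is divided by the count only at the end when GetActualVariance is true. A small host-side check of the same formula:

```cpp
#include <cstdio>

// Reference of BlockwiseWelford::Merge (pairwise Welford combination).
// "m2" is the running sum of squared deviations, as in the struct above;
// the actual variance is m2 / count, applied only at the very end.
struct Partial { float mean; float m2; int count; };

Partial merge(Partial a, Partial b)
{
    const int count = a.count + b.count;
    const float b_over_total = count == 0 ? 0.f : static_cast<float>(b.count) / count;
    const float delta = b.mean - a.mean;
    return {a.mean + delta * b_over_total,
            a.m2 + b.m2 + delta * delta * a.count * b_over_total,
            count};
}

int main()
{
    // Two halves of {1, 2, 3, 4}: merged mean 2.5, population variance 1.25.
    Partial left{1.5f, 0.5f, 2};  // {1, 2}
    Partial right{3.5f, 0.5f, 2}; // {3, 4}
    Partial all = merge(left, right);
    std::printf("mean=%f var=%f count=%d\n", all.mean, all.m2 / all.count, all.count);
    return 0;
}
```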
+ +#pragma once + +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/get_shift.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" + +namespace ck { + +// clang-format off +// Assume: +// 1) work_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data +// 2) work_buffer has AccDataType elements, and space size is no less than BlockSize +// 3) in_out_value is the input data in vgpr from each thread +// 4) in_out_value is the over-written reduced output in vgpr for each thread +// clang-format on +template > +struct PartitionedBlockwiseReduction +{ + static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1), + "The product of cluster lengths should be same as BlockSize!"); + + static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0); + static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1); + + static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements"); + + static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + template + __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value) + { + static_assert(is_same{}, + "Buffer data type should be consistent as AccDataType!"); + + constexpr auto cluster_len_shift = get_shift(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); + + const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}]; + const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}]; + + work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value; + + __syncthreads(); + + static_for<0, cluster_len_shift, 1>{}([&](auto I) { + constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I()); + + if(thread_k_cluster_id < indOffset) + { + index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx); + index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + + make_tuple(0, indOffset)); + + AccDataType opData1 = work_buffer[offset1]; + AccDataType opData2 = work_buffer[offset2]; + Accumulation::Calculate(opData1, opData2); + work_buffer(offset1) = opData1; + } + + __syncthreads(); + }); + + index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0)); + + in_out_value = work_buffer[offset]; + }; +}; + +// clang-format off +// Assume: +// 1) work_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data +// 2) work_buffer has AccDataType elements, and space size is no less than BlockSize +// 3) in_out_value is the input data in vgpr from each thread +// 4) in_out_value is the over-written reduced output in vgpr for each thread +// clang-format on +template > +struct PartitionedBlockwiseReduction_v2 +{ + static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1), + "The product of cluster lengths should be same as BlockSize!"); + + static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0); + static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1); + + static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements"); + + static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( + 
make_tuple(Number{}, Number{})); + + static constexpr auto thread_cluster_desc = ThreadClusterDesc{}; + + template + __device__ static void Reduce(BufferType& work_buffer, AccDataType& in_out_value) + { + static_assert(is_same{}, + "Buffer data type should be consistent as AccDataType!"); + + constexpr auto cluster_len_shift = get_shift(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); + + const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}]; + const auto thread_k_cluster_id = thread_cluster_idx[Number<1>{}]; + + work_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value; + + __syncthreads(); + + static_for<0, cluster_len_shift, 1>{}([&](auto I) { + constexpr index_t indOffset = 1 << (cluster_len_shift - 1 - I()); + + if(thread_k_cluster_id < indOffset) + { + index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx); + index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + + make_tuple(0, indOffset)); + + AccDataType opData1 = work_buffer[offset1]; + AccDataType opData2 = work_buffer[offset2]; + Accumulation::Calculate(opData1, opData2); + work_buffer(offset1) = opData1; + } + + __syncthreads(); + }); + + index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0)); + + in_out_value = work_buffer[offset]; + }; +}; + +// clang-format off +// Assume: +// 1) work_val_buffer/work_idx_buffer is buffer (typically LDS) allocated outside as workspace, does not include any in/out data +// 2) work_val_buffer/work_idx_buffer has AccDataType/IndexDataType elements, and space size is no less than BlockSize +// 3) in_out_value/in_out_index is the input data in vgpr from each thread +// 4) in_out_value/in_out_index is the over-written reduced output in vgpr for each thread +// clang-format on +template < + typename AccDataType, + typename IndexDataType, + index_t BlockSize, + typename ThreadClusterLengths_M_K, + typename ThreadClusterArrangeOrder, + typename OpReduce, + bool PropagateNan, + typename Accumulation = + detail::AccumulateWithIndexAndNanCheck> +struct PartitionedBlockwiseReductionWithIndex +{ + static_assert(BlockSize == ThreadClusterLengths_M_K::At(0) * ThreadClusterLengths_M_K::At(1), + "The product of cluster lengths should be same as BlockSize!"); + + static constexpr auto BufferLength_M = ThreadClusterLengths_M_K::At(0); + static constexpr auto BufferLength_K = ThreadClusterLengths_M_K::At(1); + + static_assert(BufferLength_K > 1, "Parallel reduction need work on at least two elements"); + + static constexpr auto block_buf_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + // This interface accumulates on both data values and indices + template + __device__ static void Reduce(BufferType& work_val_buffer, + IdxBufferType& work_idx_buffer, + AccDataType& in_out_value, + IndexDataType& in_out_index) + { + static_assert(is_same{}, + "Buffer data type should be consistent as AccDataType!"); + static_assert(is_same{}, + "Buffer data type should be consistent as IndexDataType!"); + + constexpr auto cluster_len_shift = get_shift(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); + + const auto thread_m_cluster_id = thread_cluster_idx[Number<0>{}]; + const auto thread_k_cluster_id = 
thread_cluster_idx[Number<1>{}]; + + work_val_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_value; + work_idx_buffer(block_buf_desc_m_k.CalculateOffset(thread_cluster_idx)) = in_out_index; + + __syncthreads(); + + static_for<0, cluster_len_shift, 1>{}([&](auto I) { + constexpr index_t indOffset = 1 << I(); + + if(thread_k_cluster_id % (indOffset * 2) == 0) + { + index_t offset1 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx); + index_t offset2 = block_buf_desc_m_k.CalculateOffset(thread_cluster_idx + + make_tuple(0, indOffset)); + + AccDataType opData1 = work_val_buffer[offset1]; + AccDataType opData2 = work_val_buffer[offset2]; + IndexDataType currIndex1 = work_idx_buffer[offset1]; + IndexDataType currIndex2 = work_idx_buffer[offset2]; + + Accumulation::Calculate(opData1, opData2, currIndex1, currIndex2); + work_val_buffer(offset1) = opData1; + work_idx_buffer(offset1) = currIndex1; + } + + __syncthreads(); + }); + + index_t offset = block_buf_desc_m_k.CalculateOffset(make_tuple(thread_m_cluster_id, 0)); + + in_out_value = work_val_buffer[offset]; + in_out_index = work_idx_buffer[offset]; + }; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp new file mode 100755 index 000000000000..98cc149f4d9b --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp @@ -0,0 +1,349 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +/** + * Transfer that uses direct load instructions to copy data from global to LDS memory. + * + * Traditional loads first copy data from global to registers, and then from registers to LDS. + * Direct loads do not need an intermediate step, data is copied directly from global to LDS, + * without the use of additional registers. + * + * However, the instruction has limitations: + * - each thread must copy exactly a single DWORD - 4 bytes; + * - threads within a single wavefront must write consecutive DWORDS into LDS, + * (data in global do not need to be contiguous, each thread might have its own offset). + * + * To make sure that all the transfers finished, the `waitcnt` instruction must be used with + * `vmcnt` instead of `lgkmcnt`. 
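+ *
+ * A minimal sketch of that synchronization (illustration only, not part of this class; the
+ * surrounding pipeline is responsible for issuing it after the loads):
+ *
+ *   // wait for all outstanding global->LDS direct loads, then synchronize the block
+ *   asm volatile("s_waitcnt vmcnt(0)" ::: "memory");
+ *   __builtin_amdgcn_s_barrier();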
+ * + * Limitations of the transfer class: + * - `SrcData` must be the same as `DstData` - no possibility to convert the data type in flight; + * - `DstVectorDim` must be the last dimension; + * - `SrcVectorDim` must be the last dimension if `ScalarPerVector` is greater than 1; + * - `ScalarPerVector` times the number of bytes of `DstData` must be equal to a single DWORD = 4B + * (for examlpe if `DstData` is fp32, then `ScalarPerVector` must be 1; if `DstData` is fp16, + * `ScalarPerVector` must be 2); + * - if `ScalarPerVector` is greater than 1, the contiguous dimension in src and dst must be + * the same dimension; + * - threads in a wavefront must write contiguous data to LDS (when wavefront size is 64, + * they must write 64 contiguous DWORDs) - `ThreadClusterLengths` must be prepared in such a way + * to guarantee that. + */ +template +struct ThreadGroupTensorSliceTransfer_DirectLoad +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr auto block_slice_lengths = BlockSliceLengths{}; + static constexpr auto thread_cluster_lengths = ThreadClusterLengths{}; + + static constexpr auto thread_single_load_size = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + // After a load, each thread moves by `thread_steps` instead of loading the next elements. + // It makes the whole wavefront load contiguous memory, what is required for direct loads. + static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size; + static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps; + + static __device__ constexpr bool AreThreadClusterLengthsValid() + { + // Make sure that ThreadClusterLengths are set in a way that allows for contiguous writes to + // LDS by the threads from a single wavefront. + // Examples (assuming 64 threads in a wavefront, 128 in a thread block): + // 1. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8], + // data type = fp32 -> ScalarPerVector = 1 + // INVALID: ThreadClusterLengths = [4, 4, 8] since in the first iteration, threads 0-31 + // write [0, 0, 0] - [0, 3, 7] and thread 32 writes [1, 0, 0] instead of + // [0, 4, 0]. + // VALID: ThreadClusterLengths = [2, 8, 8] or [1, 16, 8] since in the first iteration, + // threads 0-63 write [0, 0, 0] - [0, 7, 7] -> 64 consecutive elements (DWORDs). + // 2. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8], + // data type = fp16 -> ScalarPerVector = 2 + // NOTE: ThreadClusterLengths must take into account that each thread writes two + // elements (single DWORD) along the contiguous dimension. + // INVALID: ThreadClusterLengths = [4, 4, 8] since each 8 threads would try to write + // 8 * 2 elements of K1PerBlock and there are only 8; + // ThreadClusterLengths = [4, 8, 4] since in the first iteration, threads 0-31 + // write [0, 0, 0] - [0, 7, 7] (7 since each writes 2 elements) and thread 32 + // writes [1, 0, 0] instead of [0, 8, 0]. 
+ // VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the + // first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive + // elements = 64 consecutive DWORDs. +#if defined(__gfx950__) + int num_contiguous_dwords = 4; +#else + int num_contiguous_dwords = 1; +#endif + bool is_contiguous = true; + static_for<0, nDim, 1>{}([&](auto i) { + if(is_contiguous) + { + num_contiguous_dwords *= thread_cluster_lengths[nDim - i - 1]; + } + if(thread_slice_lengths[nDim - i - 1] > 1) + { + is_contiguous = false; + } + }); + constexpr index_t wavefront_size = get_warp_size(); + const bool wave_contiguous = num_contiguous_dwords % wavefront_size == 0; + + bool thread_slice_lengths_correct = true; + static_for<0, nDim, 1>{}([&](auto i) { + if(thread_slice_lengths[i] <= 0) + { + thread_slice_lengths_correct = false; + } + }); + + return wave_contiguous && thread_slice_lengths_correct; + } + + __device__ constexpr ThreadGroupTensorSliceTransfer_DirectLoad( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin) + + { + static_assert(ck::is_same_v, + "Direct load transfer does not support datatypes conversion. Source and " + "destination data types must be the same."); + + static_assert( + DstVectorDim == nDim - 1, + "Direct load transfer requires the destination vector dimension to be the last one."); + + static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim, + "When loading more than one element per thread at once, the contiguous " + "dimension must be the same between source and destination."); + + // constexpr auto dword_bytes = 4; + // constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData); + // static_assert(bytes_per_thread_load == dword_bytes, + // "Direct load transfer requires each thread to load exactly a single " + // "DWORD of data."); + + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size(), + "Inconsistent number of dimensions across lengths and descriptors."); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "The number of threads cannot be less than the number of elements in " + "thread cluster lengths."); + + // static_assert( + // AreThreadClusterLengthsValid(), + // "Thread cluster lengths are incorrect. They must be set in a way that allows a single + // " "wavefront to write contiguous DWORDs into LDS memory. "); + + const auto thread_cluster_idx = + thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId())); + + constexpr auto wave_cluster_lengths = generate_sequence_v2( + [&](auto i) { + // FIXME: wave parallelism is not always in that dimension. 
+ // The ThreadClusterLengths{} must be bigger than wave_num; + if constexpr(ThreadClusterArrangeOrder{}.At(i) == (nDim - 3)) + { + return Number{}; + } + else + { + return I1; + } + }, + Number{}); + + constexpr auto wave_thread_cluster_lengths = ThreadClusterLengths{} / wave_cluster_lengths; + constexpr auto wave_single_load_size = + wave_thread_cluster_lengths * thread_single_load_size; + constexpr auto wave_cluster_desc_ = + make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{}); + + const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId() / 64)); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size; + const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size; + + SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin); + // We don't need threadwise offset for lds since it was calculate by HW + // We still need input the wavewise offset. + SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + src_slice_origin_ = src_slice_origin_idx; + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + dst_slice_origin_ = dst_slice_origin_idx; + } + + __device__ void ResetDstSliceWindow(const DstDesc& dst_desc) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global, + "Source data must come from a global memory buffer."); + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "Destination data must be stored in an LDS memory buffer."); + + static_assert( + ck::is_same_v, remove_cvref_t>, + "SrcBuffer and SrcData data types must be consistent."); + static_assert( + ck::is_same_v, remove_cvref_t>, + "DstBuffer and DstData data types must be consistent."); + + constexpr auto dst_access_lengths = thread_slice_lengths; + + const auto dst_forward_steps = generate_steps(dst_desc, 1); + const auto dst_backward_steps = generate_steps(dst_desc, -1); + const auto src_forward_steps = generate_steps(src_desc, 1); + const auto src_backward_steps = generate_steps(src_desc, -1); + + // Loop over the destination block and copy data. + static_ford{}([&](auto ordered_dst_access_idx) { + const auto src_offset = src_coord_.GetOffset(); + const auto dst_offset = __builtin_amdgcn_readfirstlane(dst_coord_.GetOffset()); + + // Check if src data is not in the logic padding area. + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + src_buf.template DirectCopyToLds, ScalarPerVector>( + dst_buf, src_offset, dst_offset, is_src_valid); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // Decide whether to move forward or backward. 
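+ // The sweep direction for each dimension is derived from the parity of the flattened index
+ // of all slower-varying dimensions, giving a serpentine (zig-zag) traversal: every iteration
+ // then moves the coordinates by exactly one forward or backward step instead of jumping back
+ // to the origin of a dimension.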
+ constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]); + move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]); + } + else + { + move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]); + move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]); + } + } + }); + }); + + // Reset the destination slice since the entire buffer has been already filled. + ResetDstSliceWindow(dst_desc); + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + src_slice_origin_ = src_slice_origin_ + step; + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_); + } + + template + __device__ auto generate_steps(const DescType& desc, int sign) + { + return generate_tuple( + [&](auto i) { + Index step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + step_idx(j) = (i.value == j.value) ? sign * thread_steps[i] : 0; + }); + + return make_tensor_coordinate_step(desc, step_idx); + }, + Number{}); + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + SrcCoord src_coord_; + DstCoord dst_coord_; + Index src_slice_origin_; + Index dst_slice_origin_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp new file mode 100755 index 000000000000..3e9e50112616 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp @@ -0,0 +1,405 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +/** + * Transfer that uses direct load instructions to copy data from global to LDS memory. + * + * Traditional loads first copy data from global to registers, and then from registers to LDS. + * Direct loads do not need an intermediate step, data is copied directly from global to LDS, + * without the use of additional registers. + * + * However, the instruction has limitations: + * - each thread must copy exactly a single DWORD - 4 bytes; + * - threads within a single wavefront must write consecutive DWORDS into LDS, + * (data in global do not need to be contiguous, each thread might have its own offset). + * + * To make sure that all the transfers finished, the `waitcnt` instruction must be used with + * `vmcnt` instead of `lgkmcnt`. 
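+ *
+ * Compared to the plain direct-load transfer, this gather variant also takes one source
+ * offset per index along `GatherDim`; in `Run()` that offset is simply added to the computed
+ * source address, so data along `GatherDim` can be fetched from non-contiguous source
+ * locations while the LDS destination stays contiguous.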
+ * + * Limitations of the transfer class: + * - `SrcData` must be the same as `DstData` - no possibility to convert the data type in flight; + * - `DstVectorDim` must be the last dimension; + * - `SrcVectorDim` must be the last dimension if `ScalarPerVector` is greater than 1; + * - `ScalarPerVector` times the number of bytes of `DstData` must be equal to a single DWORD = 4B + * (for examlpe if `DstData` is fp32, then `ScalarPerVector` must be 1; if `DstData` is fp16, + * `ScalarPerVector` must be 2); + * - if `ScalarPerVector` is greater than 1, the contiguous dimension in src and dst must be + * the same dimension; + * - threads in a wavefront must write contiguous data to LDS (when wavefront size is 64, + * they must write 64 contiguous DWORDs) - `ThreadClusterLengths` must be prepared in such a way + * to guarantee that. + */ +template +struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr auto block_slice_lengths = BlockSliceLengths{}; + static constexpr auto thread_cluster_lengths = ThreadClusterLengths{}; + + static constexpr auto thread_single_load_size = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + // After a load, each thread moves by `thread_steps` instead of loading the next elements. + // It makes the whole wavefront load contiguous memory, what is required for direct loads. + static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size; + static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps; + static constexpr index_t gather_num = thread_slice_lengths.At(Number{}); + + static __device__ constexpr bool AreThreadClusterLengthsValid() + { + // Make sure that ThreadClusterLengths are set in a way that allows for contiguous writes to + // LDS by the threads from a single wavefront. + // Examples (assuming 64 threads in a wavefront, 128 in a thread block): + // 1. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8], + // data type = fp32 -> ScalarPerVector = 1 + // INVALID: ThreadClusterLengths = [4, 4, 8] since in the first iteration, threads 0-31 + // write [0, 0, 0] - [0, 3, 7] and thread 32 writes [1, 0, 0] instead of + // [0, 4, 0]. + // VALID: ThreadClusterLengths = [2, 8, 8] or [1, 16, 8] since in the first iteration, + // threads 0-63 write [0, 0, 0] - [0, 7, 7] -> 64 consecutive elements (DWORDs). + // 2. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8], + // data type = fp16 -> ScalarPerVector = 2 + // NOTE: ThreadClusterLengths must take into account that each thread writes two + // elements (single DWORD) along the contiguous dimension. + // INVALID: ThreadClusterLengths = [4, 4, 8] since each 8 threads would try to write + // 8 * 2 elements of K1PerBlock and there are only 8; + // ThreadClusterLengths = [4, 8, 4] since in the first iteration, threads 0-31 + // write [0, 0, 0] - [0, 7, 7] (7 since each writes 2 elements) and thread 32 + // writes [1, 0, 0] instead of [0, 8, 0]. 
+ // VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the + // first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive + // elements = 64 consecutive DWORDs. +#if defined(__gfx950__) + int num_contiguous_dwords = 4; +#else + int num_contiguous_dwords = 1; +#endif + bool is_contiguous = true; + static_for<0, nDim, 1>{}([&](auto i) { + if(is_contiguous) + { + num_contiguous_dwords *= thread_cluster_lengths[nDim - i - 1]; + } + if(thread_slice_lengths[nDim - i - 1] > 1) + { + is_contiguous = false; + } + }); + constexpr index_t wavefront_size = get_warp_size(); + const bool wave_contiguous = num_contiguous_dwords % wavefront_size == 0; + + bool thread_slice_lengths_correct = true; + static_for<0, nDim, 1>{}([&](auto i) { + if(thread_slice_lengths[i] <= 0) + { + thread_slice_lengths_correct = false; + } + }); + + return wave_contiguous && thread_slice_lengths_correct; + } + + __device__ constexpr ThreadGroupTensorSliceTransfer_Gather_DirectLoad( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const StaticallyIndexedArray& gather_offsets) + : gather_offsets_(gather_offsets) + { + static_assert(ck::is_same_v, + "Direct load transfer does not support datatypes conversion. Source and " + "destination data types must be the same."); + + static_assert( + DstVectorDim == nDim - 1, + "Direct load transfer requires the destination vector dimension to be the last one."); + + static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim, + "When loading more than one element per thread at once, the contiguous " + "dimension must be the same between source and destination."); + + // constexpr auto dword_bytes = 4; + // constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData); + // static_assert(bytes_per_thread_load == dword_bytes, + // "Direct load transfer requires each thread to load exactly a single " + // "DWORD of data."); + + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size(), + "Inconsistent number of dimensions across lengths and descriptors."); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "The number of threads cannot be less than the number of elements in " + "thread cluster lengths."); + + // static_assert( + // AreThreadClusterLengthsValid(), + // "Thread cluster lengths are incorrect. They must be set in a way that allows a single + // " "wavefront to write contiguous DWORDs into LDS memory. 
"); + + const auto thread_cluster_idx = + thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId())); + + constexpr auto wave_cluster_lengths = generate_sequence_v2( + [&](auto i) { + if constexpr(ThreadClusterArrangeOrder{}.At(i) == (nDim - 3)) + { + return Number{}; + } + else + { + return I1; + } + }, + Number{}); + + constexpr auto wave_thread_cluster_lengths = ThreadClusterLengths{} / wave_cluster_lengths; + constexpr auto wave_single_load_size = + wave_thread_cluster_lengths * thread_single_load_size; + constexpr auto wave_cluster_desc_ = + make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{}); + + const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId() / 64)); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size; + const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size; + + SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin); + // We don't need threadwise offset for lds since it was calculate by HW + // We still need input the wavewise offset. + SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + auto adjusted_src_origin_idx = [&]() { + Index idx; + static_for<0, nDim, 1>{}([&](auto i) { + idx(i) = i.value == GatherDim ? 0 : src_slice_origin_idx[Number{}]; + }); + return idx; + }(); + + // CK_PRINT(); + // CK_PRINT(); + + src_coord_ = make_tensor_coordinate(src_desc, adjusted_src_origin_idx); + src_slice_origin_ = adjusted_src_origin_idx; + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + dst_slice_origin_ = dst_slice_origin_idx; + } + + __device__ void ResetDstSliceWindow(const DstDesc& dst_desc) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global, + "Source data must come from a global memory buffer."); + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "Destination data must be stored in an LDS memory buffer."); + + static_assert( + ck::is_same_v, remove_cvref_t>, + "SrcBuffer and SrcData data types must be consistent."); + static_assert( + ck::is_same_v, remove_cvref_t>, + "DstBuffer and DstData data types must be consistent."); + + constexpr auto dst_access_lengths = thread_slice_lengths; + + const auto dst_forward_steps = generate_steps(dst_desc, 1); + const auto dst_backward_steps = generate_steps(dst_desc, -1); + const auto src_forward_steps = generate_steps(src_desc, 1); + const auto src_backward_steps = generate_steps(src_desc, -1); + + // Loop over the destination block and copy data. 
+ static_ford{}([&](auto ordered_dst_access_idx) { + IndexType gather_offset = gather_offsets_[ordered_dst_access_idx[Number{}]]; + // src_coord_xor_ = src_coord_; + // src_coord_xor_.GetIndex().At(I0) = + // src_coord_.GetIndex().At(I0) ^ ((threadIdx.x % 64) / 8); + Index new_index = src_coord_.GetIndex(); + new_index(I0) = src_coord_.GetIndex().At(I0) ^ ((threadIdx.x % 64) / 8); + src_coord_xor_ = make_tensor_coordinate(src_desc, new_index); + + const IndexType src_offset = src_coord_xor_.GetOffset() + gather_offset; + const IndexType dst_offset = __builtin_amdgcn_readfirstlane(dst_coord_.GetOffset()); + + // Check if src data is not in the logic padding area. + // Leave the HW for oob checking + // const bool is_src_valid = + // coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, + // src_coord_); + + src_buf.template DirectCopyToLds, ScalarPerVector>( + dst_buf, src_offset, dst_offset, true); + + constexpr auto move_src_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1; + }); + move_on_dim_(i) &= i.value != GatherDim; + }); + + return move_on_dim_; + } + (); + + constexpr auto move_dst_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // Decide whether to move forward or backward. + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + static_for<0, nDim, 1>{}([&](auto i) { + // Move the source coordinate. + if constexpr(move_src_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]); + } + else + { + move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]); + } + } + + // Move the destination coordinate. + if constexpr(move_dst_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]); + } + else + { + move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]); + } + } + }); + }); + + // Reset the destination slice since the entire buffer has been already filled. + ResetDstSliceWindow(dst_desc); + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + src_slice_origin_ = src_slice_origin_ + step; + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_); + } + + template + __device__ auto generate_steps(const DescType& desc, int sign) + { + return generate_tuple( + [&](auto i) { + Index step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + step_idx(j) = (i.value == j.value) ? 
sign * thread_steps[i] : 0; + }); + + return make_tensor_coordinate_step(desc, step_idx); + }, + Number{}); + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + SrcCoord src_coord_; + SrcCoord src_coord_xor_; + DstCoord dst_coord_; + Index src_slice_origin_; + Index dst_slice_origin_; + StaticallyIndexedArray gather_offsets_; + // static constexpr auto a_grid_xor_desc = make_naive_tensor_descriptor_packed( + // make_tuple(Number{}, Number{}, Number{})); +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp new file mode 100755 index 000000000000..bbbe012730fd --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp @@ -0,0 +1,199 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp" + +namespace ck { + +/** + * @brief Blockwise data transfer + * + * This version does following things to avoid scratch memory issue + * 1. Use StaticallyIndexedArray instead of C array for thread buffer + * 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor + * 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate + * + */ +template +struct ThreadGroupTensorSliceTransfer_v4r1 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const SrcElementwiseOperation& src_element_op, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const DstElementwiseOperation& dst_element_op) + : threadwise_transfer_(src_desc, + make_zero_multi_index(), + src_element_op, + dst_desc, + make_zero_multi_index(), + dst_element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ constexpr auto GetSrcThreadScratchIdx() + { + return threadwise_transfer_.template GetSrcThreadScratchIdx(); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id); + } + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id) + { + RunRead(src_desc, src_buf, thread_scratch_id); + RunWrite(dst_desc, dst_buf, thread_scratch_id); + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v3r1; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp new file mode 100755 index 000000000000..ab826bb04164 --- /dev/null +++ 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp" + +namespace ck { + +/** + * @brief Blockwise data transfer with dequantization + * + * RunRead would load low-precision data and scale data. + * RunWrite would process dequantization process. + * Assume Scale is identical along K-dimension + * + * This version does following things to avoid scratch memory issue + * 1. Use StaticallyIndexedArray instead of C array for thread buffer + * 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor + * 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate + * + */ +template +struct ThreadGroupTensorSliceTransfer_v4r1_dequant +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + static constexpr auto scale_thread_slice_lengths = + BlockScaleSliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_dequant( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const SrcElementwiseOperation& src_element_op, + const ScaleDesc& scale_desc, + const Index& scale_block_slice_origin, + const ScaleElementwiseOperation& scale_element_op, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const DstElementwiseOperation& dst_element_op) + : threadwise_transfer_(src_desc, + make_zero_multi_index(), + src_element_op, + scale_desc, + make_zero_multi_index(), + scale_element_op, + dst_desc, + make_zero_multi_index(), + dst_element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{} && + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetScaleSliceOrigin( + scale_desc, scale_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id); + } + } + + // With the assumption, scale scratch is always one + template + __device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunScaleRead(scale_desc, scale_buf); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id); + } + } + + // We don't prefer use this API directly + /* + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id) + { + RunRead(src_desc, src_buf, thread_scratch_id); + RunWrite(dst_desc, dst_buf, thread_scratch_id); + } + */ + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + // With the assumption, scale buffer don't need move slice window method + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v3r1_dequant; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp new file mode 100755 index 000000000000..92aef65388f0 --- /dev/null +++ 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp @@ -0,0 +1,201 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp" + +namespace ck { + +/** + * @brief Blockwise data transfer + * + * This version does following things to avoid scratch memory issue + * 1. Use StaticallyIndexedArray instead of C array for thread buffer + * 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor + * 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate + * + */ +template +struct ThreadGroupTensorSliceTransfer_v4r1_gather +{ + static constexpr auto I0 = Number<0>{}; + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + static constexpr index_t gather_num = thread_slice_lengths.At(Number{}); + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_gather( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const SrcElementwiseOperation& src_element_op, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const DstElementwiseOperation& dst_element_op, + const StaticallyIndexedArray& gather_offsets) + : threadwise_transfer_(src_desc, + make_zero_multi_index(), + src_element_op, + dst_desc, + make_zero_multi_index(), + dst_element_op, + gather_offsets) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + threadwise_transfer_.SetSrcSliceOrigin( + src_desc, src_block_slice_origin + thread_cluster_idx * thread_slice_lengths); + threadwise_transfer_.SetDstSliceOrigin( + dst_desc, dst_block_slice_origin + thread_cluster_idx * thread_slice_lengths); + } + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ constexpr auto GetSrcThreadScratchIdx() + { + return threadwise_transfer_.template GetSrcThreadScratchIdx(); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_desc, src_buf, thread_scratch_id); + } + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_desc, dst_buf, thread_scratch_id); + } + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id) + { + RunRead(src_desc, src_buf, thread_scratch_id); + RunWrite(dst_desc, dst_buf, thread_scratch_id); + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v3r1_gather; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp new file mode 100755 index 000000000000..aa1f7c573597 --- /dev/null +++ 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp" + +namespace ck { + +/** + * @brief Blockwise data transfer + * + * This version does following things to avoid scratch memory issue + * 1. Use StaticallyIndexedArray instead of C array for thread buffer + * 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor + * 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate + * + */ +template +struct ThreadGroupTensorSliceTransfer_v4r2 +{ + static constexpr index_t nDim = + remove_reference_t>::GetNumOfDimension(); + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v4r2( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + + { + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == SrcDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_for<0, nSrc, 1>{}([&](auto src_i) { + static_assert(nDim == + remove_cvref_t>::GetNumOfDimension(), + "wrong! nDim not consistent"); + }); + + static_for<0, nDst, 1>{}([&](auto dst_i) { + static_assert(nDim == + remove_cvref_t>::GetNumOfDimension(), + "wrong! nDim not consistent"); + }); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); + } + } + + template + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers& dst_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffer& src_bufs, + const DstDescs& dst_descs, + DstBuffer& dst_bufs, + Number thread_scratch_id) + { + RunRead(src_descs, src_bufs, thread_scratch_id); + RunWrite(dst_descs, dst_bufs, thread_scratch_id); + } + + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v3r2; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp new file mode 100755 index 000000000000..905a59f56e3b --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp @@ -0,0 +1,134 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct ThreadGroupTensorSliceTransfer_v6r1 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r1(const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_desc, + make_zero_multi_index(), + dst_desc, + make_zero_multi_index(), + element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v6r1; + + 
ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp new file mode 100755 index 000000000000..83cb9fb5de4e --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp @@ -0,0 +1,164 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct ThreadGroupTensorSliceTransfer_v6r1r2 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r1r2( + const SrcDesc& src_desc, + const Index& src_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_desc, + make_zero_multi_index(), + dst_desc, + make_zero_multi_index(), + element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.template Run( + src_desc, src_buf, dst_desc, dst_buf); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrcSliceOrigin(src_desc, + src_block_slice_origin + thread_data_idx_begin); + } + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_block_slice_origin) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v6r1r2; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp new file mode 100755 index 000000000000..17110c8358db --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro 
Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. It does not keep reference to tensor descriptor +// 3. Run() does not construct new tensor coordinate +template +struct ThreadGroupTensorSliceTransfer_v6r2 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r2(const Src0Desc& src0_desc, + const Index& src0_block_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src0_desc, + make_zero_multi_index(), + src1_desc, + make_zero_multi_index(), + dst_desc, + make_zero_multi_index(), + element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrc0SliceOrigin( + src0_desc, src0_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetSrc1SliceOrigin( + src1_desc, src1_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf); + } + } + + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step); + } + } + + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v6r2; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp new file mode 100755 index 000000000000..9a5317dd126a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp @@ -0,0 +1,183 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp" + +namespace ck { + +// this version does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray instead of C array for thread buffer +// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor +// 3. 
ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate +template +struct ThreadGroupTensorSliceTransfer_v6r3 +{ + static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + using Index = MultiIndex; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, + const Index& src0_block_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_block_slice_origin, + const Src2Desc& src2_desc, + const Index& src2_block_slice_origin, + const DstDesc& dst_desc, + const Index& dst_block_slice_origin, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src0_desc, + make_zero_multi_index(), + src1_desc, + make_zero_multi_index(), + src2_desc, + make_zero_multi_index(), + dst_desc, + make_zero_multi_index(), + element_op) + + { + static_assert(nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == remove_cvref_t::GetNumOfDimension() && + nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + threadwise_transfer_.SetSrc0SliceOrigin( + src0_desc, src0_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetSrc1SliceOrigin( + src1_desc, src1_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetSrc2SliceOrigin( + src2_desc, src2_block_slice_origin + thread_data_idx_begin); + threadwise_transfer_.SetDstSliceOrigin(dst_desc, + dst_block_slice_origin + thread_data_idx_begin); + } + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const Src2Desc& src2_desc, + const Src2Buffer& src2_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run( + src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf, dst_desc, dst_buf); + } + } + + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc0SliceWindow(src0_desc, step); + } + } + + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc1SliceWindow(src1_desc, step); + } + } + + __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, const Index& 
step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrc2SliceWindow(src2_desc, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_desc, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v6r3; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp new file mode 100755 index 000000000000..993d90e356de --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp @@ -0,0 +1,172 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp" + +namespace ck { + +// Thread-group level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// +// Does following things to avoid scratch memory issue +// 1. Pass tensor descritpors by reference (or tuple of references) +// 2. Does not keep reference to tensor descriptor +// 3. 
Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename ThreadClusterLengths, + typename ThreadClusterArrangeOrder, + typename DimAccessOrder, + index_t VectorDim, + index_t ScalarPerVector, + typename ThreadTransferSrcResetCoordinateAfterRunFlags, + typename ThreadTransferDstResetCoordinateAfterRunFlags> +struct ThreadGroupTensorSliceTransfer_v7 +{ + static constexpr index_t nDim = + remove_cvref_t>::GetNumOfDimension(); + + static constexpr index_t nSrc = remove_cvref_t::Size(); + static constexpr index_t nDst = remove_cvref_t::Size(); + + using Index = MultiIndex; + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v7( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + { + static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && + nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && + nDst == DstDatas::Size() && nDst == DstDescs::Size() && + nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(), + "wrong!"); + + static_for<0, nSrc, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == DimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.Run(src_descs, src_bufs, dst_descs, dst_bufs); + } + } + + template + __device__ void + MoveSrcSliceWindow(const SrcDescs& src_descs, Number iSrc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step); + } + } + + template + __device__ void + MoveDstSliceWindow(const DstDescs& dst_descs, Number iDst, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step); + } + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v7; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp new file mode 100755 index 000000000000..0a0bcbac38c2 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { + +// Thread-group level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. 
ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// +// Does following things to avoid scratch memory issue +// 1. Pass tensor descritpors by reference (or tuple of references) +// 2. Does not keep reference to tensor descriptor +// 3. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename ThreadClusterLengths, + typename ThreadClusterArrangeOrder, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + index_t SrcScalarPerVector, + index_t DstScalarPerVector, + typename ThreadTransferSrcResetCoordinateAfterRunFlags, + typename ThreadTransferDstResetCoordinateAfterRunFlags, + index_t NumThreadScratch = 1> +struct ThreadGroupTensorSliceTransfer_v7r2 +{ + static constexpr index_t nDim = + remove_cvref_t>::GetNumOfDimension(); + + static constexpr index_t nSrc = remove_cvref_t::Size(); + static constexpr index_t nDst = remove_cvref_t::Size(); + + using Index = MultiIndex; + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v7r2( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + { + static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && + nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && + nDst == DstDatas::Size() && nDst == DstDescs::Size() && + nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(), + "wrong!"); + + static_for<0, nSrc, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! 
ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); + } + } + + template + using is_tuple = decltype(ck::declval().IsTuple()); + + template + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers dst_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + if constexpr(is_detected::value) + threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id); + else + threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + template + __device__ void + MoveSrcSliceWindow(const SrcDescs& src_descs, Number iSrc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step) + { + static_for<0, SrcDescs::Size(), 1>{}( + [&](auto i) { MoveSrcSliceWindow(src_descs, i, step); }); + } + + template + __device__ void + MoveDstSliceWindow(const DstDescs& dst_descs, Number iDst, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step) + { + static_for<0, DstDescs::Size(), 1>{}( + [&](auto i) { MoveDstSliceWindow(dst_descs, i, step); }); + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v7r2; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp new file mode 100755 index 000000000000..46d0c6ac2eb4 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp @@ -0,0 +1,220 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { + +// Thread-group level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// +// Does following things to avoid scratch memory issue +// 1. Pass tensor descritpors by reference (or tuple of references) +// 2. Does not keep reference to tensor descriptor +// 3. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename ThreadClusterLengths, + typename ThreadClusterArrangeOrder, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + typename SrcScalarPerVectors, + index_t DstScalarPerVector, + typename ThreadTransferSrcResetCoordinateAfterRunFlags, + typename ThreadTransferDstResetCoordinateAfterRunFlags, + index_t NumThreadScratch = 1> +struct ThreadGroupTensorSliceTransfer_v7r3 +{ + static constexpr index_t nDim = + remove_cvref_t>::GetNumOfDimension(); + + static constexpr index_t nSrc = remove_cvref_t::Size(); + static constexpr index_t nDst = remove_cvref_t::Size(); + + using Index = MultiIndex; + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + + __device__ constexpr ThreadGroupTensorSliceTransfer_v7r3( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + { + static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && + nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && + nDst == DstDatas::Size() && nDst == DstDescs::Size() && + nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(), + "wrong!"); + + static_for<0, nSrc, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! 
threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + + const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths; + + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); + } + } + + template + using is_tuple = decltype(std::declval().IsTuple()); + + template + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers dst_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + if constexpr(is_detected::value) + threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id); + else + threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + template + __device__ void + MoveSrcSliceWindow(const SrcDescs& src_descs, Number iSrc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step) + { + static_for<0, SrcDescs::Size(), 1>{}( + [&](auto i) { MoveSrcSliceWindow(src_descs, i, step); }); + } + + template + __device__ void + MoveDstSliceWindow(const DstDescs& dst_descs, Number iDst, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step) + { + static_for<0, DstDescs::Size(), 1>{}( + [&](auto i) { MoveDstSliceWindow(dst_descs, i, step); }); + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v7r3; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git 
a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp new file mode 100755 index 000000000000..bee0b01a74f7 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp @@ -0,0 +1,241 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { + +// Thread-group level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// +// Does following things to avoid scratch memory issue +// 1. Pass tensor descritpors by reference (or tuple of references) +// 2. Does not keep reference to tensor descriptor +// 3. Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename ThreadClusterLengths, + typename ThreadClusterArrangeOrder, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + typename SrcScalarPerVectors, + index_t DstScalarPerVector, + typename ThreadTransferSrcResetCoordinateAfterRunFlags, + typename ThreadTransferDstResetCoordinateAfterRunFlags, + typename IndexType, + index_t ScatterDim = 1, + bool OutputScatter = true, + index_t ScatterWeightIdx = 3, + index_t NumThreadScratch = 1> +struct ThreadGroupTensorSliceTransfer_v7r3_scatter +{ + static constexpr index_t nDim = + remove_cvref_t>::GetNumOfDimension(); + + static constexpr index_t mod_num = + ThreadClusterLengths{}.At(Number<3>{}); // Dirty HACK FELIX, TODO fix + static constexpr index_t nSrc = remove_cvref_t::Size(); + static constexpr index_t nDst = remove_cvref_t::Size(); + + using Index = MultiIndex; + + static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{}; + static constexpr index_t scatter_num = thread_slice_lengths.At(Number{}); + + __device__ constexpr ThreadGroupTensorSliceTransfer_v7r3_scatter( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_block_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_block_slice_origins, + const ElementwiseOperation& element_op) + : threadwise_transfer_(src_descs, + StaticallyIndexedArray{}, + dst_descs, + StaticallyIndexedArray{}, + element_op) + { + static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && + nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && + nDst == DstDatas::Size() && nDst == DstDescs::Size() && + nDst == ThreadTransferDstResetCoordinateAfterRunFlags::Size(), + "wrong!"); + + static_for<0, nSrc, 1>{}([&](auto i) { + static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + 
static_assert( + nDim == remove_cvref_t>::GetNumOfDimension(), + "wrong!"); + }); + + static_assert(nDim == ThreadClusterLengths::Size() && + nDim == ThreadClusterArrangeOrder::Size() && + nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(), + "wrong! nDim not consistent"); + + static_assert( + is_same{}, + "wrong! threads should be mapped to cover entire slicing window"); + + static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(), + "wrong! ThreadGroup::GetNumOfThread() too small"); + + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + const auto src_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId())); + const auto src_thread_slice_origins = generate_tuple( + [&](auto i) { + return src_block_slice_origins[i] + + src_thread_cluster_idx * thread_slice_lengths; + }, + Number{}); + + const auto dst_thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex( + make_multi_index(OutputScatter ? ThreadGroup::GetThreadId() % mod_num + : ThreadGroup::GetThreadId())); + const auto dst_thread_slice_origins = generate_tuple( + [&](auto i) { + return dst_block_slice_origins[i] + + dst_thread_cluster_idx * thread_slice_lengths; + }, + Number{}); + + threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins); + threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins); + } + } + + template + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); + } + } + + template + using is_tuple = decltype(std::declval().IsTuple()); + + template + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers dst_bufs, + StaticallyIndexedArray& scatter_offsets, + Number thread_scratch_id = Number{}) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + if constexpr(is_detected::value) + threadwise_transfer_.RunWrite( + dst_descs, dst_bufs, scatter_offsets, thread_scratch_id); + else + threadwise_transfer_.RunWrite( + dst_descs, tie(dst_bufs), scatter_offsets, thread_scratch_id); + } + } + + template + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs, + StaticallyIndexedArray& scatter_offsets) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs, scatter_offsets); + } + + template + __device__ void + MoveSrcSliceWindow(const SrcDescs& src_descs, Number iSrc, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveSrcSliceWindow(src_descs, iSrc, step); + } + } + + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step) + { + static_for<0, SrcDescs::Size(), 1>{}( + [&](auto i) { MoveSrcSliceWindow(src_descs, i, step); }); + } + + template + __device__ void + MoveDstSliceWindow(const DstDescs& dst_descs, Number iDst, const Index& step) + { + if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or + 
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) + { + threadwise_transfer_.MoveDstSliceWindow(dst_descs, iDst, step); + } + } + + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step) + { + static_for<0, DstDescs::Size(), 1>{}( + [&](auto i) { MoveDstSliceWindow(dst_descs, i, step); }); + } + + private: + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + + using ThreadwiseTransfer = + ThreadwiseTensorSliceTransfer_v7r3_scatter; + + ThreadwiseTransfer threadwise_transfer_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp new file mode 100755 index 000000000000..dc08a2c88b2a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace conv_tensor_rearrange_op { + +struct BaseConvTensorRearrangeOp +{ +}; + +struct ImageToColumn : public BaseConvTensorRearrangeOp +{ + static constexpr const char* name = "Image to Column"; +}; + +struct ColumnToImage : public BaseConvTensorRearrangeOp +{ + static constexpr const char* name = "Column to Image"; +}; + +template ::value, + bool>::type = false> +std::ostream& operator<<(std::ostream& os, const BaseConvTensorRearrangeOp&) +{ + os << Op::name; + return os; +} + +} // namespace conv_tensor_rearrange_op +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp new file mode 100755 index 000000000000..cab5e213651d --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct ConvolutionBackwardDataSpecialization +{ + Default, + Filter1x1Stride1Pad0, +}; + +inline std::string +getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecialization& s) +{ + switch(s) + { + case ConvolutionBackwardDataSpecialization::Default: return "Default"; + case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; + default: return "Unrecognized specialization!"; + } +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp new file mode 100755 index 000000000000..01bb806789c3 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
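The specialization-to-string helpers above exist purely for logging and kernel naming. A minimal, hypothetical usage sketch follows; the header path is taken from this diff, the surrounding program is assumed.

#include <iostream>
#include <string>
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"

int main()
{
    using ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
    using ck::tensor_operation::device::getConvBackwardDataSpecializationString;

    // Prints "Filter1x1Stride1Pad0".
    std::cout << getConvBackwardDataSpecializationString(
                     ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
              << std::endl;
    return 0;
}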
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct ConvolutionBackwardWeightSpecialization +{ + Default, + Filter1x1Stride1Pad0, + Filter1x1Pad0, + OddC, +}; + +inline std::string +getConvBackwardWeightSpecializationString(const ConvolutionBackwardWeightSpecialization& s) +{ + switch(s) + { + case ConvolutionBackwardWeightSpecialization::Default: return "Default"; + case ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0: + return "Filter1x1Stride1Pad0"; + case ConvolutionBackwardWeightSpecialization::Filter1x1Pad0: return "Filter1x1Pad0"; + case ConvolutionBackwardWeightSpecialization::OddC: return "OddC"; + default: return "Unrecognized specialization!"; + } +} +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp new file mode 100755 index 000000000000..cf20025d46e1 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#ifndef CK_CODE_GEN_RTC +#include +#endif + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct ConvolutionForwardSpecialization +{ + Default, + Filter1x1Pad0, + Filter1x1Stride1Pad0, + OddC, + Filter3x3, +}; + +#ifndef CK_CODE_GEN_RTC +inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s) +{ + switch(s) + { + case ConvolutionForwardSpecialization::Default: return "Default"; + case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0"; + case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0"; + case ConvolutionForwardSpecialization::OddC: return "OddC"; + case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3"; + default: return "Unrecognized specialization!"; + } +} +#endif + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp new file mode 100755 index 000000000000..5a2b082c31aa --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
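Callers typically derive the forward specialization above from the concrete filter, stride and padding sizes. The helper below is only a sketch of that dispatch; its name and the exact rules are assumptions, not part of this header.

#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"

// Hypothetical selection helper for a 2D convolution problem (illustrative rules).
inline ck::tensor_operation::device::ConvolutionForwardSpecialization
choose_fwd_spec(int fil_h, int fil_w, int stride, int pad)
{
    using S = ck::tensor_operation::device::ConvolutionForwardSpecialization;
    if(fil_h == 1 && fil_w == 1 && stride == 1 && pad == 0)
        return S::Filter1x1Stride1Pad0;
    if(fil_h == 1 && fil_w == 1 && pad == 0)
        return S::Filter1x1Pad0;
    if(fil_h == 3 && fil_w == 3)
        return S::Filter3x3;
    return S::Default;
}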
+ +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceAvgPoolBwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_dout, + void* p_din, + std::vector dout_n_k_wos_lengths, + std::vector dout_n_k_wos_strides, + std::vector din_n_k_wos_length, + std::vector din_n_k_wos_strides, + std::vector window_k_c_xs_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_base.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_base.hpp new file mode 100755 index 000000000000..9285211519a4 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +#include +#include +#include +#include + +#include "ck/stream_config.hpp" +#endif + +namespace ck { +namespace tensor_operation { +namespace device { + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +#define GET_OBJECT_NAME_IMLP \ + std::optional GetObjectName() const override \ + { \ + std::string str = __PRETTY_FUNCTION__; \ + static std::regex obj_name_expr{" (.*)::GetObjectName"}; \ + std::smatch match; \ + if(!std::regex_search(str, match, obj_name_expr)) \ + { \ + return str; \ + } \ + return std::string(match[1]) + ';'; \ + } + +#define GET_TEMPLATE_INFO_IMPL \ + std::optional GetTemplateInfo() const override \ + { \ + std::string str = __PRETTY_FUNCTION__; \ + static std::regex template_expr{"\\[(.*)\\]"}; \ + std::smatch match; \ + if(!std::regex_search(str, match, template_expr)) \ + { \ + return std::nullopt; \ + } \ + return std::string(match[1]); \ + } + +#define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL +#endif + +#ifndef CK_CODE_GEN_RTC +struct BaseArgument +{ + BaseArgument() = default; + BaseArgument(const BaseArgument&) = default; + BaseArgument& operator=(const BaseArgument&) = default; + + virtual ~BaseArgument() {} + + void* p_workspace_ = nullptr; +}; + +struct BaseInvoker +{ + BaseInvoker() = default; + BaseInvoker(const BaseInvoker&) = default; + BaseInvoker& operator=(const BaseInvoker&) = default; + + virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{}) + { + return float{0}; + } + + virtual ~BaseInvoker() {} +}; +#endif + +struct BaseOperator +{ + BaseOperator() = default; + BaseOperator(const BaseOperator&) = default; + BaseOperator& operator=(const BaseOperator&) = default; +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) + virtual bool IsSupportedArgument(const BaseArgument*) { return false; } + virtual std::string GetTypeString() const { return ""; } + + virtual std::string GetTypeIdName() const { return typeid(*this).name(); } + + virtual std::optional GetObjectName() const { return std::nullopt; } + + virtual std::optional GetTemplateInfo() const { return std::nullopt; } + + virtual std::string GetTypeIdHashCode() const + { + std::ostringstream oss; + + oss << std::hex << typeid(*this).hash_code(); + + return oss.str(); + }; + 
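    // Typical workspace flow for operators that need device scratch memory
    // (illustrative only; the allocation call below is an assumption, not part of this interface):
    //   std::size_t ws_size = op.GetWorkSpaceSize(arg_ptr.get());
    //   void* ws = nullptr;
    //   hipMalloc(&ws, ws_size);                   // hypothetical HIP allocation
    //   op.SetWorkSpacePointer(arg_ptr.get(), ws); // stored in BaseArgument::p_workspace_
    //   invoker.Run(arg_ptr.get(), StreamConfig{});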
+ virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; } + + virtual void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const + { + assert(p_arg); + p_arg->p_workspace_ = p_workspace; + } +#endif + virtual ~BaseOperator() {} +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp new file mode 100755 index 000000000000..ee7af0117d17 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] +// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] +// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +template +struct DeviceBatchedContractionMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp new file mode 100755 index 000000000000..fcb460829337 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp @@ -0,0 +1,110 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
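To make the contraction layout above concrete: a single batched GEMM E[G0, M0, N0] = A[G0, M0, K0] * B[G0, N0, K0] uses exactly one index per dimension group. A hedged sketch of the length/stride vectors for packed tensors follows; the concrete numbers are illustrative.

// G0 = 8, M0 = 256, N0 = 512, K0 = 64, packed layouts (innermost stride 1).
std::vector<ck::index_t> a_lengths{8, 256, 64},  a_strides{256 * 64, 64, 1};
std::vector<ck::index_t> b_lengths{8, 512, 64},  b_strides{512 * 64, 64, 1};
std::vector<ck::index_t> e_lengths{8, 256, 512}, e_strides{256 * 512, 512, 1};
// These vectors feed the corresponding *_lengths / *_strides parameters of
// DeviceBatchedContractionMultipleD::MakeArgumentPointer.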
+ +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemm : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t BatchStrideA, + ck::index_t BatchStrideB, + ck::index_t BatchStrideC, + ck::index_t Batch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +struct DeviceBatchedGemmV2BScale : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t StrideScaleB, + ck::index_t BatchStrideA, + ck::index_t BatchStrideB, + ck::index_t BatchStrideC, + ck::index_t BatchStrideScaleB, + const void* p_b_scale, + ck::index_t Batch, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual bool GetPermuteB() = 0; + virtual ck::index_t GetKPerBlock() = 0; +}; + +template +using DeviceBatchedGemmPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp new file mode 100755 index 000000000000..acd779b2db34 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp @@ -0,0 +1,50 @@ +#pragma once +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct BatchedGemmEPermuteDesc +{ + ck::index_t G0_, G1_, M_, N_; + ck::index_t stride_G0_, stride_G1_, stride_M_, stride_N_; +}; + +template +struct DeviceBatchedGemmEPermute : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t stride_A, + index_t stride_B, + index_t batch_stride_A, + index_t batch_stride_B, + BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, + index_t BatchCount, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp new file mode 100755 index 000000000000..91b4b6b91b6b --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
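For the plain batched GEMM interface above, the leading strides and batch strides are related in the usual packed way. A short note, assuming row-major A/C and column-major B (an assumption, stated only for orientation):

//   StrideA = K;  BatchStrideA = M * K;  // A[b][m][k] at b * BatchStrideA + m * StrideA + k
//   StrideB = K;  BatchStrideB = N * K;  // B held as [N][K] per batch (column-major B)
//   StrideC = N;  BatchStrideC = M * N;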
+ +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemmGemm : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t O, + ck::index_t Batch, + ck::index_t StrideA, + ck::index_t StrideB0, + ck::index_t StrideB1, + ck::index_t StrideC, + ck::index_t BatchStrideA, + ck::index_t BatchStrideB0, + ck::index_t BatchStrideB1, + ck::index_t BatchStrideC, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + Acc0ElementwiseOperation acc0_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp new file mode 100755 index 000000000000..8fb4a71f556a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemmMultiD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsisiten NumDTensor"); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + const std::array& StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +struct DeviceBatchedGemmV2MultiD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! 
inconsisiten NumDTensor"); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + const std::array& StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t KBatch) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp new file mode 100755 index 000000000000..8234e29486b1 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator +{ + static constexpr index_t NumD0Tensor = D0sDataType::Size(); + static constexpr index_t NumD1Tensor = D1sDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a0, + const void* p_b0, + std::array p_d0s, + const void* p_b1, + std::array p_d1s, + void* p_e1, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t O, + ck::index_t Batch, + ck::index_t StrideA0, + ck::index_t StrideB0, + std::array StrideD0s, + ck::index_t StrideB1, + std::array StrideD1s, + ck::index_t StrideE1, + ck::index_t BatchStrideA0, + ck::index_t BatchStrideB0, + std::array BatchStrideD0s, + ck::index_t BatchStrideB1, + std::array BatchStrideD1s, + ck::index_t BatchStrideE1, + A0ElementwiseOperation a0_element_op, + B0ElementwiseOperation b0_element_op, + CDE0ElementwiseOperation cde0_element_op, + B1ElementwiseOperation b1_element_op, + CDE1ElementwiseOperation cde1_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp new file mode 100755 index 000000000000..204b09cad406 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
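Note (editorial, not part of the diff): DeviceBatchedGemmGemm and DeviceBatchedGemmMultipleDGemmMultipleD above chain two GEMMs per batch. Reading the argument lists, M/N/K size the first GEMM and O is the free dimension of the second; that role for O is inferred from the parameter names, not stated in the headers. The sketch below shows the composition on dense row-major buffers, with the D tensors and intermediate elementwise stage reduced to a single scaling lambda.

// Sketch only: two chained GEMMs as suggested by the (M, N, K, O) arguments.
// Acc0 = A[M,K] * B0[K,N]; E1 = Acc0' * B1[N,O], where Acc0' is Acc0 after the
// intermediate elementwise/D stage. Shapes inferred; layouts assumed row-major.
#include <cstdio>
#include <vector>

static std::vector<float> gemm(const std::vector<float>& x, const std::vector<float>& y,
                               int M, int K, int N) // x:[M,K], y:[K,N], row-major
{
    std::vector<float> out(M * N, 0.f);
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                out[m * N + n] += x[m * K + k] * y[k * N + n];
    return out;
}

int main()
{
    const int M = 4, K = 8, N = 6, O = 5;
    std::vector<float> a(M * K, 1.f), b0(K * N, 1.f), b1(N * O, 1.f);
    auto acc0 = gemm(a, b0, M, K, N);  // first GEMM: [M, N]
    for(float& v : acc0) v *= 0.5f;    // stand-in for acc0/cde0 elementwise stage
    auto e1 = gemm(acc0, b1, M, N, O); // second GEMM: [M, O]
    std::printf("E1 is %d x %d, E1[0] = %f\n", M, O, e1[0]);
    return 0;
}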
+ +#pragma once +#ifndef __HIPCC_RTC__ +#include +#include +#endif + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template // TODO: enum for mask type +struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator +{ +#ifndef __HIPCC_RTC__ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t O, + ck::index_t Batch, + ck::index_t StrideA, + ck::index_t StrideB0, + ck::index_t StrideB1, + ck::index_t StrideC, + ck::index_t BatchStrideA, + ck::index_t BatchStrideB0, + ck::index_t BatchStrideB1, + ck::index_t BatchStrideC, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + Acc0ElementwiseOperation acc0_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +#endif +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp new file mode 100755 index 000000000000..be8105c96726 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_base.hpp" +#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemmSoftmaxGemmPermute : public BaseOperator +{ + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + virtual std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::vector& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths + const std::vector& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides + const std::vector& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths + const std::vector& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ns_strides, + const std::array, NumAcc1Bias> + acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths + const std::array, NumAcc1Bias> + acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + C0DEElementwiseOperation c0de_element_op, + B1ElementwiseOperation b1_element_op, + C1DEElementwiseOperation c1de_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp new file mode 100755 
index 000000000000..2c0da692570b --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchNormBwd : public BaseOperator +{ + static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim; + + virtual std::unique_ptr + MakeArgumentPointer(const std::array xyLengths, + const std::array xStrides, + const std::array dyStrides, + const std::array dxStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnDscaleDbiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* p_dy, + const void* p_scale, + const void* p_savedMean, + const void* p_savedInvVar, + double epsilon, + const DyElementwiseOp dy_elementwise_op, + void* p_dx, + void* p_dscale, + void* p_dbias) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceBatchNormBwdPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp new file mode 100755 index 000000000000..e3962e177ee8 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchNormFwd : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer( + const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* bnScale, + const void* bnBias, + double epsilon, + const YElementwiseOp y_elementwise_op, + void* p_y, + void* resultSaveMean, + void* resultSaveInvVariance, + double exponentialAverageFactor, + void* resultRunningMean, + void* resultRunningVariance) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceBatchNormFwdPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp new file mode 100755 index 000000000000..69103b6f4429 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
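Note (editorial, not part of the diff): the DeviceBatchNormFwd signature above carries the usual training-mode outputs (saved mean, saved inverse variance, and running statistics blended with exponentialAverageFactor). The 1-D scalar reference below is a reminder of those semantics only; the biased-variance and running-average conventions are the common ones and are an assumption here, not something this header states.

// Sketch only: training-mode batchnorm over one channel, mirroring the outputs
// named in DeviceBatchNormFwd (saved mean, saved inv-variance, running stats).
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<double> x = {1.0, 2.0, 3.0, 4.0};
    const double epsilon = 1e-5, gamma = 1.0, beta = 0.0;
    const double exponentialAverageFactor = 0.1;
    double runningMean = 0.0, runningVariance = 1.0;

    double mean = 0.0, var = 0.0;
    for(double v : x) mean += v;
    mean /= x.size();
    for(double v : x) var += (v - mean) * (v - mean);
    var /= x.size();                                            // biased variance (assumed convention)
    const double invVariance = 1.0 / std::sqrt(var + epsilon);  // resultSaveInvVariance

    for(double& v : x) v = gamma * (v - mean) * invVariance + beta; // y_elementwise_op omitted

    // running statistics blended with exponentialAverageFactor (assumed common convention)
    runningMean     = (1 - exponentialAverageFactor) * runningMean + exponentialAverageFactor * mean;
    runningVariance = (1 - exponentialAverageFactor) * runningVariance + exponentialAverageFactor * var;

    std::printf("savedMean=%f savedInvVar=%f y0=%f runningMean=%f\n",
                mean, invVariance, x[0], runningMean);
    return 0;
}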
+ +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchNormInfer : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer( + const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* bnScale, + const void* bnBias, + double epsilon, + const YElementwiseOp y_elementwise_op, + const void* estimatedMean, + const void* estimatedInvVariance, + void* p_y) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceBatchNormInferPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_cgemm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_cgemm.hpp new file mode 100755 index 000000000000..44dedeeef95c --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_cgemm.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceCGemm : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer(const void* p_a_real, + const void* p_a_imag, + const void* p_b_real, + const void* p_b_imag, + void* p_c_real, + void* p_c_imag, + void* p_workspace, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + virtual std::size_t GetWorkspaceSize(index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC) const = 0; +}; + +template +using DeviceCGemmPtr = std::unique_ptr< + DeviceCGemm>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp new file mode 100755 index 000000000000..3116d404742e --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A0[M0, M1, ... K0, K1, ...], ... +// input : B0[N0, N1, ... K0, K1, ...], ... +// input : D0[M0, M1, ... N0, N1, ...], D1[M0, M1, ... N0, N1, ...], ... +// output : E[M0, M1, ... N0, N1, ...] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct DeviceContractionMultipleABD : public BaseOperator +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + const std::array, NumATensor>& a_ms_ks_lengths, + const std::array, NumATensor>& a_ms_ks_strides, + const std::array, NumBTensor>& b_ns_ks_lengths, + const std::array, NumBTensor>& b_ns_ks_strides, + const std::array, NumDTensor>& d_ms_ns_lengths, + const std::array, NumDTensor>& d_ms_ns_strides, + const std::vector& e_ms_ns_length, + const std::vector& e_ms_ns_stride, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp new file mode 100755 index 000000000000..dffb51d0ba75 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[M0, M1, M2, ..., K0, K1, K2, ...] +// B[N0, N1, N2, ..., K0, K1, K2, ...] +// D[M0, M1, M2, ..., N0, N1, N2, ...] +// E[M0, M1, M2, ..., N0, N1, N2, ...] +template +struct DeviceContractionMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_ms_ns_lengths, + const std::vector& a_ms_ks_strides, + const std::vector& b_ns_ks_lengths, + const std::vector& b_ns_ks_strides, + const std::array, NumDTensor>& ds_ms_ns_lengths, + const std::array, NumDTensor>& ds_ms_ns_strides, + const std::vector& e_ms_ns_lengths, + const std::vector& e_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp new file mode 100755 index 000000000000..eb1b85ec822a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
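Note (editorial, not part of the diff): DeviceContractionMultipleABD and DeviceContractionMultipleD above contract grouped multi-indices (A[M0, M1, ..., K0, K1, ...] with B[N0, N1, ..., K0, K1, ...]). A useful way to read those signatures is that, for densely packed tensors, the contraction collapses to an ordinary GEMM on the flattened extents. The example below is illustrative only, using one particular rank split.

// Sketch only: a packed multi-index contraction viewed as a flattened GEMM.
// E[M0,M1,N0,N1] = sum_{K0,K1} A[M0,M1,K0,K1] * B[N0,N1,K0,K1]
#include <cstdio>
#include <vector>

int main()
{
    const int M0 = 2, M1 = 3, N0 = 2, N1 = 2, K0 = 4, K1 = 2;
    const int M = M0 * M1, N = N0 * N1, K = K0 * K1; // flattened GEMM extents

    std::vector<float> a(M * K, 1.f); // packed A[M0,M1,K0,K1] == row-major [M,K]
    std::vector<float> b(N * K, 1.f); // packed B[N0,N1,K0,K1] == row-major [N,K]
    std::vector<float> e(M * N, 0.f); // packed E[M0,M1,N0,N1] == row-major [M,N]

    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                e[m * N + n] += a[m * K + k] * b[n * K + k];

    std::printf("flattened GEMM is %dx%dx%d, e[0]=%f (expected %d)\n", M, N, K, e[0], K);
    return 0;
}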
+ +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvBwdData : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(void* p_in, + const void* p_wei, + const void* p_out, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp new file mode 100755 index 000000000000..4dc11dbefd73 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_conv_fwd.hpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvFwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp new file mode 100755 index 000000000000..7d3845666cf2 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
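Note (editorial, not part of the diff): DeviceConvBwdData and DeviceConvFwd above take input, filter and output spatial lengths together with strides, dilations and pads; the operators do not derive one from the others, so callers must pass a consistent set. Under the usual convolution convention (an assumption here, not something these headers state), the lengths are related as sketched below.

// Sketch only: the standard relation between convolution spatial lengths, useful
// for sanity-checking the *_spatial_lengths / strides / dilations / pads arguments.
#include <cstdio>

int conv_out_length(int in_len, int filter_len, int stride, int dilation,
                    int left_pad, int right_pad)
{
    const int effective_filter = dilation * (filter_len - 1) + 1;
    return (in_len + left_pad + right_pad - effective_filter) / stride + 1;
}

int main()
{
    // e.g. Hi = 224, Y = 3, stride 2, dilation 1, pads 1/1  ->  Ho = 112
    std::printf("Ho = %d\n", conv_out_length(224, 3, 2, 1, 1, 1));
    return 0;
}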
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvFwdBiasActivation : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + const void* p_bias, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceConvFwdBiasActivationPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp new file mode 100755 index 000000000000..3a49ac632e7d --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceConvFwdBiasActivationAdd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + const void* p_bias, + const void* p_resi, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceConvFwdBiasActivationAddPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp new file mode 100755 index 000000000000..b28204fd841c --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/** + * \brief Convolution Tensor Rearrange. + * + * This Device operator supports converting an image to + * the GEMM representation (Image to Column) and + * converting a GEMM form to the image (Column to Image). 
+ * Supported layouts: + * [G, N, Di, Hi, Wi, C] <-> [G, N * Do * Ho * Wo, Z * Y * X * C] + * [N, Di, Hi, Wi, G, C] <-> [N * Do * Ho * Wo, G, Z * Y * X * C] + * + * \tparam NDimSpatial Number of spatial dimensions. + * \tparam ImageLayout Input Layout. + * \tparam InputDataType Input Data Type. + * \tparam OutputDataType Output Data Type. + * \tparam ConvTensorRearrangeOp Operation type: ImageToColumn, ColumnToImage. + */ +template +struct DeviceConvTensorRearrange : public BaseOperator +{ + + /** + * \brief Make argument pointer for image to column. + * + * \param p_in A pointer to the device memory of the input image. + * \param p_out A pointer to the device memory of the output. + * \param G Convolution number of groups. + * \param N Convolution batch size. + * \param C Convolution number of channels. + * \param input_spatial_lengths Input spatial lengths. + * \param filter_spatial_lengths Filter spatial lengths. + * \param output_spatial_lengths Output spatial lengths. + * \param image_g_n_c_wis_strides Image strides in order [G, N, C, D, H, W]. + * \param gemm_g_m_k_strides Gemm form strides. + * \param conv_filter_strides Convolution filter strides. + * \param conv_filter_dilations Convolution filter dilations. + * \param input_left_pads Convolution left pads. + * \param input_right_pads Convolution right pads. + * \return Pointer to the argument. + */ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + void* p_out, + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_elementwise.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_elementwise.hpp new file mode 100755 index 000000000000..db0e4bd83f46 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_elementwise.hpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
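Note (editorial, not part of the diff): the layouts quoted in the DeviceConvTensorRearrange comment ([G, N, Di, Hi, Wi, C] <-> [G, N * Do * Ho * Wo, Z * Y * X * C]) fix the GEMM-form extents, so gemm_g_m_k_strides describes a G x M x K tensor with M = N * Do * Ho * Wo and K = Z * Y * X * C. A concrete instance, purely for illustration (the numbers are made up):

// Sketch only: GEMM-form extents produced by image-to-column for the 3-D case,
// using the layout quoted in the DeviceConvTensorRearrange documentation.
#include <cstdio>

int main()
{
    const long G = 2, N = 8, C = 64;      // groups, batch, channels per group
    const long Do = 16, Ho = 28, Wo = 28; // output spatial lengths
    const long Z = 3, Y = 3, X = 3;       // filter spatial lengths

    const long M = N * Do * Ho * Wo;      // rows of the column matrix
    const long K = Z * Y * X * C;         // columns: one filter footprint
    std::printf("GEMM form per group: M = %ld, K = %ld (G = %ld)\n", M, K, G);
    return 0;
}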
+ +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceElementwise : public BaseOperator +{ + static constexpr int NumInput = InDataTypeTuple::Size(); + static constexpr int NumOutput = OutDataTypeTuple::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; // namespace device + +template +using DeviceElementwisePtr = std::unique_ptr< + DeviceElementwise>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp new file mode 100755 index 000000000000..c56a947ec9e1 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceElementwiseNormalization : public BaseOperator +{ + static constexpr int NumInput = InDataTypeTuple::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::array, NumInput> inStridesArray, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector reduceDims, + double epsilon, + const std::array in_dev_buffers, + const void* p_gamma, + const void* p_beta, + void* p_y, + XElementwiseOperation x_elementwise_op, + YElementwiseOperation y_elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceElementwiseNormalizationPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_elementwise_scale.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_elementwise_scale.hpp new file mode 100755 index 000000000000..ac6ff0a96026 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_elementwise_scale.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/** + * \note This structure is deprecated (left for backwards compatibility). Please use + * DeviceElementwise from device_elementwise.hpp. 
+ */ +template +struct DeviceElementwise : public BaseOperator +{ + static constexpr int NumInput = InDataTypeTuple::Size(); + static constexpr int NumOutput = OutDataTypeTuple::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op, + UnaryOperation unary_op, + Scale scale_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; // namespace device + +template +using DeviceElementwisePtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm.hpp new file mode 100755 index 000000000000..adf909821dbb --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm.hpp @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp new file mode 100755 index 000000000000..a7f42c3b35e2 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_bias_e_permute.hpp @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct DEGridDesc_M0_M1_M2_N0_N1 +{ + ck::index_t M0_, M1_, M2_, N0_, N1_; + ck::index_t stride_M0_, stride_M1_, stride_M2_, stride_N0_, stride_N1_; +}; + +// input : A[M, K], B[K, N], +// input : D[M, N], ... 
+// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D) +template +struct DeviceGemmBiasCPermute : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_d, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + DEGridDesc_M0_M1_M2_N0_N1 d_gride_desc, + DEGridDesc_M0_M1_M2_N0_N1 e_gride_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp new file mode 100755 index 000000000000..acb18efabfe2 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Dequantization of input tensor could not be decoupled from gridwisegemm pipeline +// As input tensor thread buffer declared inside blockwise-gemm pipeline. + +template +struct DeviceGemm_dequantB : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_scale, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp new file mode 100755 index 000000000000..cbb9fadc6d6f --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A0[M, K], B0[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct DeviceGemmMultipleABD : public BaseOperator +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + ck::index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp new file mode 100755 index 000000000000..6ebdbc5054d3 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp @@ -0,0 +1,200 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#ifndef __HIPCC_RTC__ +#include +#endif + +#include "ck/utility/array.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[M, K], B[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + +#ifndef __HIPCC_RTC__ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + std::array StrideDs, + ck::index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +#endif +}; + +// GEMM: +// input : A[M, K], B[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleDSplitK : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + +#ifndef __HIPCC_RTC__ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + std::array StrideDs, + ck::index_t StrideE, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +#endif +}; + +// GEMM: +// input : A[M, K], B[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + +#ifndef CK_CODE_GEN_RTC + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + std::array StrideDs, + ck::index_t StrideE, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual int GetPreShuffleParameters() = 0; +#endif +}; + +template +struct DeviceMoEGemmMXBPreShuffle : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + +#ifndef CK_CODE_GEN_RTC + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + std::array p_ds, + void* p_e, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideAScale, + ck::index_t StrideB, + ck::index_t StrideBScale, + std::array StrideDs, + ck::index_t StrideE, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual int GetPreShuffleParameters() = 0; +#endif +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp new file mode 100755 index 000000000000..abf49bdab2c5 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp @@ -0,0 +1,108 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[M, K], B[K, N], +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct DeviceGemmMultipleD_ABScale : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const ck::index_t StrideA, + const ck::index_t StrideB, + const std::array StrideDs, + const ck::index_t StrideE, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +struct DeviceGemmMultipleD_BlockScale_BPreshuffle : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const ck::index_t StrideA, + const ck::index_t StrideB, + const std::array StrideDs, + const ck::index_t StrideE, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual int GetPreShuffleParameters() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp new file mode 100755 index 000000000000..0258858fe50e --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// output : H[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// H = layernorm(E) +// Assume: +// D0, D1, ... 
and E have the same layout +// Calculate mean & variance along N dimension in layernorm(E) +template +struct DeviceGemmMultipleDLayernorm : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + const void* p_gamma, + const void* p_beta, + void* p_h, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideH, + double epsilon, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + HElementwiseOperation h_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; // namespace device + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp new file mode 100755 index 000000000000..539e83f7cb8a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// FIXME: DeviceGemmReduce type need to well define the problem +// GEMM: +// input : A[AK0, M, AK1] +// input : B[AK0, N, AK1] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// output : R0[M], R1[M], ... +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Q0 = reduce0(q_op0(E)), Q1 = reduce1(q_op0(E)), ... +// R0 = r_op0(Q0), R1 = r_op1(Q1), ... +// Assume: +// D0, D1, ... and E have the same layout +template +struct DeviceGemmMultipleDMultipleR : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t NumRTensor = RsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + std::array p_rs, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + std::array StrideDs, + ck::index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + QsElementwiseOperation qs_element_op, + RsElementwiseOperation rs_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmMultipleDMultipleRPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp new file mode 100755 index 000000000000..0562e452ac55 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp @@ -0,0 +1,88 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
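Note (editorial, not part of the diff): several interfaces in this patch (DeviceGemmMultipleD, its SplitK, layernorm and multiple-R variants above, and the contraction forms earlier) share the convention spelled out in their comments: C = a_op(A) * b_op(B), then E = cde_op(C, D0, D1, ...), with each D[M, N] read at the same (m, n) as C. The sketch below shows that epilogue for two D tensors and a bias-plus-residual style stand-in for cde_op.

// Sketch only: the "multiple D" epilogue convention, E = cde_op(C, D0, D1, ...),
// applied element-wise after the GEMM accumulation.
#include <cstdio>
#include <vector>

int main()
{
    const int M = 2, N = 3;
    std::vector<float> c  = {1, 2, 3, 4, 5, 6};       // GEMM result C[M, N]
    std::vector<float> d0 = {10, 10, 10, 10, 10, 10}; // e.g. a bias-like D tensor
    std::vector<float> d1 = {1, 0, 1, 0, 1, 0};       // e.g. a residual-like D tensor
    std::vector<float> e(M * N);

    auto cde_op = [](float c_mn, float d0_mn, float d1_mn) {
        return c_mn + d0_mn + d1_mn; // stand-in for CDEElementwiseOperation
    };
    for(int i = 0; i < M * N; ++i)
        e[i] = cde_op(c[i], d0[i], d1[i]);

    std::printf("E[0][0] = %f\n", e[0]); // 1 + 10 + 1 = 12
    return 0;
}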
+ +#pragma once + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMX : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideAScale, + ck::index_t StrideB, + ck::index_t StrideBScale, + ck::index_t StrideC, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +struct DeviceGemmMX_BPreshuffle : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideAScale, + ck::index_t StrideB, + ck::index_t StrideBScale, + ck::index_t StrideC, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual int GetPreShuffleParameters() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp new file mode 100755 index 000000000000..eaa7671c6424 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_reduce.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// FIXME: DeviceGemmReduce type need to well define the problem +template +struct DeviceGemmReduce : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_ops, + std::array reduce_out_element_ops, + ck::index_t BatchCount = 1) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmReducePtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp new file mode 100755 index 000000000000..dfc27e726eaa --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_splitk.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmSplitK : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmSplitKPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_streamk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_streamk.hpp new file mode 100755 index 000000000000..ed081ad7fc80 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_streamk.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmStreamK : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t NumSKBlocks = 0) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceGemmStreamKPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp new file mode 100755 index 000000000000..ad79c1f61c95 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
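Note (editorial, not part of the diff): DeviceGemmSplitK, DeviceGemmStreamK and the KBatch/KSplit arguments that recur in this patch partition the reduction dimension across work-groups; each partition produces a partial result that is then combined (by atomics or a separate reduction pass, depending on the operator and its StreamKReductionStrategy). The sketch below shows only the arithmetic of that partitioning, not the CK scheduling.

// Sketch only: split-K partitioning of the reduction loop. Each of the KBatch
// slices computes a partial dot product; the partials are summed afterwards,
// standing in for the atomic or reduction-pass combine on the device.
#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    const int K = 37, KBatch = 4;
    const int KPerSlice = (K + KBatch - 1) / KBatch; // ceil-divide K across slices

    std::vector<float> a(K, 1.f), b(K, 2.f);
    std::vector<float> partial(KBatch, 0.f);

    for(int kb = 0; kb < KBatch; ++kb)
        for(int k = kb * KPerSlice; k < std::min(K, (kb + 1) * KPerSlice); ++k)
            partial[kb] += a[k] * b[k];

    float c = 0.f;
    for(float p : partial) c += p; // final combine of the partial results
    std::printf("c = %f (expected %f)\n", c, 2.f * K);
    return 0;
}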
+ +#pragma once + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Streamk_V2 : public BaseOperator +{ + virtual std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t Streamk_sel, + ck::index_t Grid_size, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + StreamKReductionStrategy reduction_strategy = StreamKReductionStrategy::Atomic) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp new file mode 100755 index 000000000000..b251fb97b985 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmV2 : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t KSplit, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual bool GetPermuteA() = 0; + virtual bool GetPermuteB() = 0; + virtual ck::index_t GetKPerBlock() = 0; +}; + +template +struct DeviceGemmV2R1 : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + std::array DsStrides, + ck::index_t StrideC, + ck::index_t KSplit, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +struct DeviceGemmV2BScale : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + ck::index_t StrideScaleB, + const void* p_b_scale, + ck::index_t KSplit, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual bool GetPermuteB() = 0; + virtual ck::index_t GetKPerBlock() = 0; +}; + +template +struct DeviceGemmV2BPreshuffle : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t 
StrideC, + ck::index_t KSplit, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual bool GetPermuteA() = 0; + virtual bool GetPermuteB() = 0; + virtual ck::index_t GetKPerBlock() = 0; + virtual int GetPreShuffleParameters() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp new file mode 100755 index 000000000000..ba81948440ac --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct ContractionDesc +{ + std::vector a_ms_ks_lengths; + std::vector a_ms_ks_strides; + + std::vector b_ns_ks_lengths; + std::vector b_ns_ks_strides; + + std::array, NumDTensor> ds_ms_ns_lengths; + std::array, NumDTensor> ds_ms_ns_strides; + + std::vector e_ms_ns_lengths; + std::vector e_ms_ns_strides; +}; + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[M0, M1, M2, ..., K0, K1, K2, ...] +// B[N0, N1, N2, ..., K0, K1, K2, ...] +// D[M0, M1, M2, ..., N0, N1, N2, ...] +// E[M0, M1, M2, ..., N0, N1, N2, ...] +template +struct DeviceGroupedContractionMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(std::vector p_a_vec, + std::vector p_b_vec, + std::vector> p_ds_vec, + std::vector p_e_vec, + std::vector> contraction_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp new file mode 100755 index 000000000000..9c44bda5cac8 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Conv backward data multiple D: +// input : output image A[G, N, K, Ho, Wo] +// input : weight B[G, K, C, Y, X], +// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ... +// output : input image E[G, N, C, Hi, Wi], +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static_assert(NumDTensor == DsLayout::Size(), "wrong! 
Inconsistent NumDTensor"); + + virtual std::unique_ptr MakeArgumentPointer( + const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const ck::index_t split_k = 1) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp new file mode 100755 index 000000000000..7296e4faaa30 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGroupedConvBwdWeight : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, + void* p_wei, + const void* p_out, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp new file mode 100755 index 000000000000..6febf702f9e6 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
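// ----------------------------------------------------------------------------
// Usage sketch for the DeviceGroupedConvBwdWeight interface declared above.
// The instance type `Op`, the PassThrough element-wise ops, and the
// IsSupportedArgument()/BaseInvoker::Run() calls (inherited from device_base.hpp)
// are assumptions for illustration, not part of this header.
// Needs: <array>, <stdexcept>, "ck/tensor_operation/gpu/element/element_wise_operation.hpp".
template <typename Op, ck::index_t NDimSpatial>
float run_grouped_conv_bwd_weight(
    Op& op,
    const void* p_in,
    void* p_wei,
    const void* p_out,
    const std::array<ck::index_t, NDimSpatial + 3>& in_lengths,
    const std::array<ck::index_t, NDimSpatial + 3>& in_strides,
    const std::array<ck::index_t, NDimSpatial + 3>& wei_lengths,
    const std::array<ck::index_t, NDimSpatial + 3>& wei_strides,
    const std::array<ck::index_t, NDimSpatial + 3>& out_lengths,
    const std::array<ck::index_t, NDimSpatial + 3>& out_strides,
    const std::array<ck::index_t, NDimSpatial>& conv_strides,
    const std::array<ck::index_t, NDimSpatial>& conv_dilations,
    const std::array<ck::index_t, NDimSpatial>& left_pads,
    const std::array<ck::index_t, NDimSpatial>& right_pads,
    ck::index_t split_k)
{
    // PassThrough matches the common case where the instance was generated with
    // identity element-wise operations on input, weight and output.
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    auto p_arg = op.MakeArgumentPointer(p_in, p_wei, p_out,
                                        in_lengths, in_strides,
                                        wei_lengths, wei_strides,
                                        out_lengths, out_strides,
                                        conv_strides, conv_dilations,
                                        left_pads, right_pads,
                                        PassThrough{}, PassThrough{}, PassThrough{},
                                        split_k);

    if(!op.IsSupportedArgument(p_arg.get()))
    {
        throw std::runtime_error("grouped conv bwd-weight: argument not supported");
    }

    // Run() conventionally returns the measured kernel time when timing is enabled.
    return op.MakeInvokerPointer()->Run(p_arg.get(), StreamConfig{});
}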
+ +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGroupedConvBwdWeightMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsLayout::Size(); + + virtual std::unique_ptr MakeArgumentPointer( + const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& p_ds, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp new file mode 100755 index 000000000000..025c43e75cc2 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Convolution Forward: +// input : input image A[G, N, C, Hi, Wi], +// input : weight B[G, K, C, Y, X], +// output : output image E[G, N, K, Ho, Wo] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvFwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in, // input image + const void* p_wei, // weight + void* p_out, // output image + const std::array& in_g_n_c_wis_lengths, + const std::array& in_g_n_c_wis_strides, + const std::array& wei_g_k_c_xs_lengths, + const std::array& wei_g_k_c_xs_strides, + const std::array& out_g_n_k_wos_lengths, + const std::array& out_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const InElementwiseOperation& in_element_op, + const WeiElementwiseOperation& wei_element_op, + const OutElementwiseOperation& out_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp new file mode 100755 index 000000000000..8c9b768a8b9c --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp @@ -0,0 +1,171 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. 
All rights reserved. + +#pragma once + +#ifndef CK_CODE_GEN_RTC +#include +#endif + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +#ifdef CK_CODE_GEN_RTC +template +using is_tuple = decltype(ck::declval().IsTuple()); +#else +template +using is_tuple = decltype(std::declval().IsTuple()); +#endif + +/** + * \brief Grouped Convolution Forward + * + * \details + * input : input image A[G, N, C, Hi, Wi], A1[G, N, C, Hi, Wi]... + * input : weight B[G, K, C, Y, X], B1[G, K, C, Y, X]... + * input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ... + * output : output image E[G, N, K, Ho, Wo] + * + * C = a_op(A, A1...) * b_op(B, B1...) + * E = cde_op(C, D0, D1, ...) + * + * \tparam NDimSpatial Number of spatial dimensions. + * \tparam ALayout Input layout (also for a1, a2...). + * \tparam BLayout Weight layout (also for b1, b2...). + * \tparam DsLayout Ds layouts. + * \tparam ELayout Output layout. + * \tparam ADataType Input data type. Pass tuple if there is multiple A. + * \tparam BDataType Weight data type. Pass tuple if there is multiple B. + * \tparam DsDataType D data types. + * \tparam EDataType Output data type. + * \tparam AElementwiseOperation A elementwise operation. + * \tparam BElementwiseOperation B elementwise operation. + * \tparam CDEElementwiseOperation CDE elementwise operation. + * \tparam AComputeType Compute data type for A tensor (default: ADataType, first if tuple passed). + * \tparam BComputeType Compute data type for B tensor (default: AComputeType). + */ +template ::value, + Number<0>, + ADataType>()), // AComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + typename BComputeType = AComputeType> +struct DeviceGroupedConvFwdMultipleABD : public BaseOperator +{ + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; + + static constexpr index_t NumATensor = GetNumABTensors(); + static constexpr index_t NumBTensor = GetNumABTensors(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor"); +#ifdef CK_CODE_GEN_RTC + using APointers = ck::conditional_t&, const void*>; + using BPointers = ck::conditional_t&, const void*>; +#else + // If DataType is tuple, user has to pass std::array with pointers. + using APointers = + ck::conditional_t&, const void*>; + using BPointers = + ck::conditional_t&, const void*>; +#endif + +#ifndef CK_CODE_GEN_RTC + + /** + * \brief Make argument pointer for grouped conv fwd. + * + * \param p_a A pointer to the input (std::array with + pointers for multiple A). + * \param p_b A pointer to the weight (std::array with + pointers for multiple B). + * \param p_ds A pointers to the Ds. + * \param p_e A pointers to the output. + * \param a_g_n_c_wis_lengths Input lengths [G, N, C, Spatial...] (for 3d). + * \param a_g_n_c_wis_strides Input strides [G, N, C, Spatial...] (for 3d). + * \param b_g_k_c_xs_lengths Weight lengths [G, K, C, Spatial...] (for 3d). + * \param b_g_k_c_xs_strides Weight strides [G, K, C, Spatial...] (for 3d). + * \param ds_g_n_k_wos_lengths Ds lengths [G, N, K, Spatial...] (for 3d). + * \param ds_g_n_k_wos_strides Ds strides [G, N, K, Spatial...] (for 3d). + * \param e_g_n_k_wos_lengths Output lengths [G, N, K, Spatial...] (for 3d). 
+ * \param e_g_n_k_wos_strides Output strides [G, N, K, Spatial...] (for 3d). + * \param conv_filter_strides Convolution filter strides. + * \param conv_filter_dilations Convolution filter dilations. + * \param input_left_pads Input left paddings. + * \param input_right_pads Input right paddings. + * \param a_element_op A elementwise operation object. + * \param b_element_op B elementwise operation object. + * \param cde_element_op CDE elementwise operation object. + * \return Pointer to the argument. + */ + virtual std::unique_ptr MakeArgumentPointer( + APointers p_a, + BPointers p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) = 0; + + virtual std::unique_ptr + MakeArgumentPointer(APointers p_a, + BPointers p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +#endif +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp new file mode 100755 index 000000000000..daf397189a4e --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/** + * \brief Grouped Convolution Forward + * + * \note This structure is deprecated (left for backwards compatibility). Please use + * DeviceGroupedConvFwdMultipleABD. + * + * \tparam NDimSpatial Number of spatial dimensions. + * \tparam ALayout Input layout (also for a1, a2...). + * \tparam BLayout Weight layout (also for b1, b2...). + * \tparam DsLayout Ds layouts. + * \tparam ELayout Output layout. + * \tparam ADataType Input data type. Pass tuple if there is multiple A. + * \tparam BDataType Weight data type. Pass tuple if there is multiple B. 
+ * \tparam DsDataType D data types. + * \tparam EDataType Output data type. + * \tparam AElementwiseOperation A elementwise operation. + * \tparam BElementwiseOperation B elementwise operation. + * \tparam CDEElementwiseOperation CDE elementwise operation. + * \tparam ComputeType Compute data type (default: ADataType, first if tuple passed). + */ +template ::value, + Number<0>, + ADataType>())> // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed +using DeviceGroupedConvFwdMultipleD = DeviceGroupedConvFwdMultipleABD; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp new file mode 100755 index 000000000000..267a970ee5e6 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp @@ -0,0 +1,185 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "device_base.hpp" +#include "ck/utility/ignore.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// +/// @brief Structure representing single GEMM problem arguments. +/// +/// The pointer to the vector of those structures is passed to the GroupedGEMM entry +/// point kernel. +/// +/// @tparam NumDTensor The number of D input tensors. +/// +template +struct GroupedGemmKernelArgument +{ + __host__ __device__ GroupedGemmKernelArgument(const void* p_a_grid_, + const void* p_b_grid_, + std::array p_ds_grid_, + void* p_e_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideE_) + : p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{p_ds_grid_}, + p_e_grid{p_e_grid_}, + M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideE{StrideE_} + { + } + + const void* p_a_grid; + const void* p_b_grid; + std::array p_ds_grid; + void* p_e_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideE; + + void Print() const + { + std::stringstream str; + for(auto sd : StrideDs) + str << sd << ","; + + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SE:" << StrideE << ", " + << "SDs: {" << str.str() << "}" + << "}" << std::endl; + } +}; + +struct GemmDesc +{ + ck::index_t M_, N_, K_; + ck::index_t stride_A_, stride_B_, stride_C_; + + std::vector stride_Ds_; +}; + +template +struct DeviceGroupedGemm : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor"); + + virtual std::unique_ptr + MakeArgumentPointer(std::vector& p_a, + std::vector& p_b, + std::vector>& p_ds, + std::vector& p_e, + std::vector& gemm_desc, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + //--------------------------------------------------------------------------------------------- + /// @brief Sets the device kernel arguments pointer and may copy data to device. 
+ /// + /// TODO: Add which kernels are using this (TileLoop * FixedNK ??) + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_dev_kernel_args The pointer to the device memory which will contain kernel + /// arguments. + /// @param[in] p_host_kernel_args The pointer to the host memory which contains kernel + /// arguments that should be copied to device memory. + /// + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, + void* p_dev_kernel_args, + const void* p_host_kernel_args) const + { + ignore = p_arg; + ignore = p_dev_kernel_args; + ignore = p_host_kernel_args; + + std::ostringstream err; + err << "This function is not implemented by the kernel: " << this->GetTypeString() + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + //---------------------------------------------------------------------------------------------- + /// @brief Sets the device kernel arguments pointer and may copy data to device. + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel + /// arguments. + /// + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const + { + ignore = p_arg; + ignore = p_dev_kernel_args; + + std::ostringstream err; + err << "This function is not implemented by the kernel: " << this->GetTypeString() + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + //---------------------------------------------------------------------------------------------- + /// @brief Gets the device kernel argument size. + /// + /// @param[in] p_arg The pointer to the Device op Argument. + /// + /// @return The device kernel argument size. + /// + virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const + { + ignore = p_arg; + + std::ostringstream err; + err << "This function is not implemented by the kernel: " << this->GetTypeString() + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp new file mode 100755 index 000000000000..780a0c30c50f --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "device_grouped_gemm_splitk.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGroupedGemmFixedNK : DeviceGroupedGemmSplitK +{ +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp new file mode 100755 index 000000000000..59483cb89fa6 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp @@ -0,0 +1,98 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
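// ----------------------------------------------------------------------------
// Usage sketch for the device-kernel-argument hooks of DeviceGroupedGemm above.
// The concrete op type, the GroupedGemmKernelArgument<0> payload, and the raw
// hipMalloc call are illustrative assumptions; only GetDeviceKernelArgSize() and
// the three-pointer SetDeviceKernelArgs() overload come from the interface itself.
// Needs: <vector>, <stdexcept>, <hip/hip_runtime.h>.
template <typename GroupedGemmOp>
void* upload_grouped_gemm_kernel_args(
    GroupedGemmOp& op,
    ck::tensor_operation::device::BaseArgument* p_arg,
    const std::vector<ck::tensor_operation::device::GroupedGemmKernelArgument<0>>& host_args)
{
    // 1. Query how much device memory the packed kernel arguments require.
    const std::size_t arg_bytes = op.GetDeviceKernelArgSize(p_arg);

    // 2. Allocate a device buffer of that size (ownership stays with the caller).
    void* p_dev_args = nullptr;
    if(hipMalloc(&p_dev_args, arg_bytes) != hipSuccess)
    {
        throw std::runtime_error("hipMalloc failed for grouped gemm kernel args");
    }

    // 3. Hand both pointers to the op: it copies the host-side argument array into
    //    the device buffer and records the device pointer inside the Argument.
    op.SetDeviceKernelArgs(p_arg, p_dev_args, host_args.data());

    return p_dev_args;
}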
+ +#pragma once + +#include +#include + +#include "device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +struct GemmMultiABDDesc +{ + ck::index_t M_, N_, K_; + + std::vector stride_As_; + std::vector stride_Bs_; + std::vector stride_Ds_; + + ck::index_t stride_C_; +}; + +/* + * \brief Grouped Gemm Multi ABD + * + * C = a_op(A, A1...) * b_op(B, B1...) + * E = cde_op(C, D0, D1, ...) + * + * \tparam AsLayout A layouts (tuple). + * \tparam BsLayout B layouts (tuple). + * \tparam DsLayout Ds layouts (tuple). + * \tparam ELayout Output layout. + * \tparam AsDataType A data types (tuple). + * \tparam BsDataType B data types (tuple). + * \tparam DsDataType D data types (tuple). + * \tparam EDataType Output data type. + * \tparam AElementwiseOperation A elementwise operation. + * \tparam BElementwiseOperation B elementwise operation. + * \tparam CDEElementwiseOperation C elementwise operation. + */ +template +struct DeviceGroupedGemmMultiABD : public BaseOperator +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static_assert(AsLayout::Size() == AsDataType::Size(), "wrong! inconsistent NumATensor"); + static_assert(BsLayout::Size() == BsDataType::Size(), "wrong! inconsistent NumBTensor"); + static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor"); + + /* + * \brief Make argument pointer for grouped gemm multi abd. + * + * \param p_as A pointers to the A. + * \param p_bs A pointers to the B. + * \param p_ds A pointers to the Ds. + * \param p_e A pointers to the E. + * \param gemm_desc Gemm descriptors for each group. + * \param a_element_op A elementwise operation object. + * \param b_element_op B elementwise operation object. + * \param cde_element_op CDE elementwise operation object. + * \return Pointer to the argument. + */ + virtual std::unique_ptr + MakeArgumentPointer(std::vector>& p_as, + std::vector>& p_bs, + std::vector>& p_ds, + std::vector& p_e, + std::vector& gemm_desc, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual void SetElementwiseOps(BaseArgument* p_arg, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) const = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp new file mode 100755 index 000000000000..05c3e796a8b7 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
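// ----------------------------------------------------------------------------
// Usage sketch for the DeviceGroupedGemmMultiABD interface above: one group with a
// single A, a single B and no D tensors (NumATensor = NumBTensor = 1, NumDTensor = 0).
// Row-major A/E, column-major B and the defaulted element-wise ops are assumptions.
// Needs: <array>, <vector>.
template <typename GroupedGemmMultiABDOp>
auto make_single_group_multi_abd_argument(GroupedGemmMultiABDOp& op,
                                          const void* p_a, const void* p_b, void* p_e,
                                          ck::index_t M, ck::index_t N, ck::index_t K)
{
    ck::tensor_operation::device::GemmMultiABDDesc desc;
    desc.M_          = M;
    desc.N_          = N;
    desc.K_          = K;
    desc.stride_As_  = {K}; // row-major A: leading dimension K
    desc.stride_Bs_  = {K}; // column-major B: leading dimension K
    desc.stride_Ds_  = {};  // no auxiliary D tensors
    desc.stride_C_   = N;   // row-major E: leading dimension N

    std::array<const void*, 1> a_ptrs{p_a};
    std::array<const void*, 1> b_ptrs{p_b};
    std::array<const void*, 0> d_ptrs{};

    std::vector<std::array<const void*, 1>> p_as{a_ptrs};
    std::vector<std::array<const void*, 1>> p_bs{b_ptrs};
    std::vector<std::array<const void*, 0>> p_ds{d_ptrs};
    std::vector<void*> p_es{p_e};
    std::vector<ck::tensor_operation::device::GemmMultiABDDesc> descs{desc};

    // Element-wise operations fall back to their defaulted values.
    return op.MakeArgumentPointer(p_as, p_bs, p_ds, p_es, descs);
}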
+ +#pragma once + +#include +#include + +#include "device_grouped_gemm_multi_abd.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct GroupedGemmMultiABDKernelArgument +{ + std::array p_as_grid; + std::array p_bs_grid; + std::array p_ds_grid; + void* p_e_grid; + + index_t M; + index_t N; + index_t K; + + std::array StrideAs; + std::array StrideBs; + std::array StrideDs; + index_t StrideE; +}; + +/* + * \brief Grouped Gemm Multi ABD Fixed NK + * + * C = a_op(A, A1...) * b_op(B, B1...) + * E = cde_op(C, D0, D1, ...) + * + * \tparam AsLayout A layouts (tuple). + * \tparam BsLayout B layouts (tuple). + * \tparam DsLayout Ds layouts (tuple). + * \tparam ELayout Output layout. + * \tparam AsDataType A data types (tuple). + * \tparam BsDataType B data types (tuple). + * \tparam DsDataType D data types (tuple). + * \tparam EDataType Output data type. + * \tparam AElementwiseOperation A elementwise operation. + * \tparam BElementwiseOperation B elementwise operation. + * \tparam CDEElementwiseOperation C elementwise operation. + */ +template +struct DeviceGroupedGemmMultiABDFixedNK : DeviceGroupedGemmMultiABD +{ + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const = 0; + virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0; + virtual void SetKBatch(BaseArgument* p_arg, index_t k_batch) const = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp new file mode 100755 index 000000000000..fae65097407c --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp @@ -0,0 +1,75 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
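// ----------------------------------------------------------------------------
// Usage sketch for the fixed-NK path above: packing one per-group
// GroupedGemmMultiABDKernelArgument (one A, one B, no Ds). The <1, 1, 0>
// tensor-count parameters are inferred from the field layout above, and the
// strides assume row-major A/E and column-major B. Whether SetDeviceKernelArgs()
// expects a host or device pointer depends on the concrete instance, so only the
// packing and the k-batch setter are shown here.
inline ck::tensor_operation::device::GroupedGemmMultiABDKernelArgument<1, 1, 0>
pack_fixed_nk_group(const void* p_a, const void* p_b, void* p_e,
                    ck::index_t M, ck::index_t N, ck::index_t K)
{
    return ck::tensor_operation::device::GroupedGemmMultiABDKernelArgument<1, 1, 0>{
        {p_a},   // p_as_grid
        {p_b},   // p_bs_grid
        {},      // p_ds_grid (no D tensors)
        p_e,     // p_e_grid
        M, N, K,
        {K},     // StrideAs (row-major A)
        {K},     // StrideBs (column-major B)
        {},      // StrideDs
        N};      // StrideE (row-major E)
}

// Typical follow-up on the op itself, assuming `p_arg` came from MakeArgumentPointer():
//   op.SetKBatch(p_arg, /*k_batch=*/4);
//   op.SetDeviceKernelArgs(p_arg, p_packed_args); // see note above on host vs. device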
+ +#pragma once + +#include +#include + +#include "device_base.hpp" +#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGroupedGemmSoftmaxGemmPermute : public BaseOperator +{ + struct ProblemDesc + { + std::vector a_gs_ms_ks_lengths; + std::vector a_gs_ms_ks_strides; + + std::vector b0_gs_ns_ks_lengths; + std::vector b0_gs_ns_ks_strides; + + std::vector b1_gs_os_ns_lengths; + std::vector b1_gs_os_ns_strides; + + std::vector c_gs_ms_os_lengths; + std::vector c_gs_ms_os_strides; + + std::vector> acc0_biases_gs_ms_ns_lengths; + std::vector> acc0_biases_gs_ms_ns_strides; + + std::vector> acc1_biases_gs_ms_os_lengths; + std::vector> acc1_biases_gs_ms_os_strides; + }; + + virtual std::unique_ptr + MakeArgumentPointer(std::vector p_a_vec, + std::vector p_b0_vec, + std::vector p_b1_vec, + std::vector p_c_vec, + std::vector> p_acc0_biases_vec, + std::vector> p_acc1_biases_vec, + std::vector problem_desc_vec, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + Acc0ElementwiseOperation acc0_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp new file mode 100755 index 000000000000..3ea650190271 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp @@ -0,0 +1,55 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include "device_grouped_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm +{ + //---------------------------------------------------------------------------------------------- + /// @brief Sets the k batch size. + /// + /// @param p_arg Pointer to the Argument we're going to change. + /// @param[in] kbatch The kbatch value. + /// + virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0; + //---------------------------------------------------------------------------------------------- + /// @brief Sets the k batch size. + /// + /// @param p_arg Pointer to the Argument we're going to change. + /// @param[in] kbatch The kbatch value. + /// + virtual void SetKBatch(BaseArgument* p_arg, index_t kbatch) const + { + this->SetKBatchSize(p_arg, kbatch); + }; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp new file mode 100755 index 000000000000..712fbfd9e9ac --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
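// ----------------------------------------------------------------------------
// Usage sketch for DeviceGroupedGemmSplitK above: the k-batch value must be set on
// the Argument before the invoker runs. The IsSupportedArgument()/Run() calls are
// assumed from the BaseOperator/BaseInvoker contract in device_base.hpp, and the
// k-batch value is illustrative. Needs: <stdexcept>.
template <typename GroupedGemmSplitKOp>
float run_grouped_gemm_with_k_batch(GroupedGemmSplitKOp& op,
                                    ck::tensor_operation::device::BaseArgument* p_arg,
                                    ck::index_t k_batch)
{
    // Split the K dimension of every group into `k_batch` partial reductions.
    // SetKBatch() is a thin backwards-compatibility alias, so either entry point works.
    op.SetKBatchSize(p_arg, k_batch);

    if(!op.IsSupportedArgument(p_arg))
    {
        throw std::runtime_error("split-K grouped gemm: k_batch not supported");
    }

    return op.MakeInvokerPointer()->Run(p_arg, StreamConfig{});
}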
+ +#pragma once + +#include "device_grouped_gemm.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// @brief Grouped GEMM kernel using output Tile Looping algorithm +/// +/// @par This kernel does not require any knowledge about input data sizes (GEMM M/N/K) +/// It requires only the number of groups to launch. Other information like +/// data pointers and GEMM sizes, packed into gemm kernel args may be all dynamic +/// (known only at kernel run-time). +/// +/// @note This kernel does not support SplitK. + +template +struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm +{ +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp new file mode 100755 index 000000000000..5a4a9cac1e8e --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// For pooling which used indexable operation, such as MaxPool, MinPool...etc +template +struct DeviceMaxPoolBwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_dout, + const void* p_indices, + void* p_din, + index_t dout_length, + index_t din_length, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp new file mode 100755 index 000000000000..f68022ca04de --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/utility/reduction_enums.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceMultipleReduce : public BaseOperator +{ + static constexpr index_t NumInputDim = Rank; + static constexpr index_t NumOutputDim = (Rank - NumReduceDim > 1) ? 
Rank - NumReduceDim : 1; + + virtual std::unique_ptr MakeArgumentPointer( + const std::array inLengths, + const std::array inStrides, + const std::array outLengths, + const std::array, NumReduction> outStrides, + const std::array reduceDims, + const std::array alphas, + const std::array betas, + const void* in_dev, + const std::array out_dev_buffers, + const InElementwiseOperationTuple in_elementwise_op_tuple, + const AccElementwiseOperationTuple acc_elementwise_op_tuple) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceMultipleReducePtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp new file mode 100755 index 000000000000..327acfeb53e9 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +template +struct DeviceNormalizationBwdData : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector dxStrides, + const std::vector reduceDims, + const void* p_dy, + const void* p_x, + const void* p_gamma, + const void* p_mean, + const void* p_invStd, + void* p_dx) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceNormalizationBwdDataPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp new file mode 100755 index 000000000000..5a8804cbf0e8 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp @@ -0,0 +1,61 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
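// ----------------------------------------------------------------------------
// Usage sketch for DeviceNormalizationBwdData above, specialised for a 2-D
// LayerNorm: an [N, D] input normalised over the last dimension (Rank = 2,
// NumReduceDim = 1). The zero-stride broadcast convention for gamma and the
// per-row mean/invStd strides are assumptions, as is the concrete op type.
// Needs: <vector>.
template <typename NormBwdDataOp>
auto make_layernorm_bwd_data_arg(NormBwdDataOp& op,
                                 ck::index_t N, ck::index_t D,
                                 const void* p_dy, const void* p_x,
                                 const void* p_gamma, const void* p_mean,
                                 const void* p_invStd, void* p_dx)
{
    const std::vector<ck::index_t> lengths       = {N, D};
    const std::vector<ck::index_t> dy_strides    = {D, 1};
    const std::vector<ck::index_t> x_strides     = {D, 1};
    const std::vector<ck::index_t> gamma_strides = {0, 1}; // gamma broadcast over N
    const std::vector<ck::index_t> stat_strides  = {1, 0}; // mean/invStd per row, broadcast over D
    const std::vector<ck::index_t> dx_strides    = {D, 1};
    const std::vector<ck::index_t> reduce_dims   = {1};    // normalised over D

    return op.MakeArgumentPointer(lengths,
                                  dy_strides, x_strides, gamma_strides,
                                  stat_strides, stat_strides /* invStd */,
                                  dx_strides, reduce_dims,
                                  p_dy, p_x, p_gamma, p_mean, p_invStd, p_dx);
}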
+ +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +template +struct DeviceNormalizationBwdGammaBeta : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const std::vector inLengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector outLengths, + const std::vector dgammaStrides, + const std::vector dbetaStrides, + const std::vector reduceDims, + const void* p_dy, + const void* p_x, + const void* p_mean, + const void* p_invStd, + void* p_dgamma, + void* p_dbeta) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceNormalizationBwdGammaBetaPtr = + std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_normalization_fwd.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_normalization_fwd.hpp new file mode 100755 index 000000000000..d252ad1d98f5 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_normalization_fwd.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +template +struct DeviceNormalizationFwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, + const std::vector reduceDims, + double epsilon, + const void* p_x, + const void* p_gamma, + const void* p_beta, + void* p_y, + void* p_savedMean, + void* p_savedInvVar, + YElementwiseOperation y_elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceNormalizationFwdPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_permute.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_permute.hpp new file mode 100755 index 000000000000..c994cf02c6aa --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_permute.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
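// ----------------------------------------------------------------------------
// Usage sketch for DeviceNormalizationFwd above: a 2-D LayerNorm forward on an
// [N, D] input normalised over D, with PassThrough as YElementwiseOperation.
// The zero-stride broadcast convention, the rank-1 strides for the saved
// statistics, and the concrete op type are assumptions.
// Needs: <vector>, "ck/tensor_operation/gpu/element/element_wise_operation.hpp".
template <typename NormFwdOp>
auto make_layernorm_fwd_arg(NormFwdOp& op,
                            ck::index_t N, ck::index_t D, double epsilon,
                            const void* p_x, const void* p_gamma, const void* p_beta,
                            void* p_y, void* p_saved_mean, void* p_saved_inv_std)
{
    const std::vector<ck::index_t> lengths       = {N, D};
    const std::vector<ck::index_t> x_strides     = {D, 1};
    const std::vector<ck::index_t> gamma_strides = {0, 1}; // broadcast over N
    const std::vector<ck::index_t> beta_strides  = {0, 1}; // broadcast over N
    const std::vector<ck::index_t> y_strides     = {D, 1};
    const std::vector<ck::index_t> save_strides  = {1};    // one mean/invStd per row
    const std::vector<ck::index_t> reduce_dims   = {1};

    return op.MakeArgumentPointer(lengths, x_strides, gamma_strides, beta_strides,
                                  y_strides, save_strides, save_strides,
                                  reduce_dims, epsilon,
                                  p_x, p_gamma, p_beta,
                                  p_y, p_saved_mean, p_saved_inv_std,
                                  ck::tensor_operation::element_wise::PassThrough{});
}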
+ +#pragma once + +#include +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DevicePermute : BaseOperator +{ + using Lengths = std::array; + using Strides = Lengths; + + virtual std::unique_ptr + MakeArgumentPointer(const Lengths& in_lengths, + const Strides& in_strides, + const Lengths& out_lengths, + const Strides& out_strides, + const void* in_dev_buffer, + void* out_dev_buffer, + ElementwiseOperation elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp new file mode 100755 index 000000000000..62071c43a1a0 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/utility/reduction_enums.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DevicePoolFwd : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in_dev, + void* p_out_dev, + void* p_out_indices_dev, + std::vector input_n_c_wis_lengths, + std::vector window_xs_lengths, + std::vector output_n_c_wos_lengths, + std::vector input_n_c_wis_stride, + std::vector output_n_c_wis_stride, + std::vector indices_n_c_wis_stride, + std::vector window_xs_strides, + std::vector window_xs_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector pooling_dims) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_put_element.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_put_element.hpp new file mode 100755 index 000000000000..17df2de37b68 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_put_element.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
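// ----------------------------------------------------------------------------
// Usage sketch for DevicePermute above: swapping the last two axes of a contiguous
// [N, H, W] tensor so that out[n, w, h] = in[n, h, w]. Expressing the permutation
// through the output strides, the 3-D instance type, and the PassThrough
// element-op are assumptions here.
// Needs: <array>, "ck/tensor_operation/gpu/element/element_wise_operation.hpp".
template <typename PermuteOp>
auto make_transpose_hw_arg(PermuteOp& op,
                           ck::index_t N, ck::index_t H, ck::index_t W,
                           const void* p_in, void* p_out)
{
    const std::array<ck::index_t, 3> in_lengths  = {N, H, W};
    const std::array<ck::index_t, 3> in_strides  = {H * W, W, 1}; // contiguous input
    const std::array<ck::index_t, 3> out_lengths = {N, H, W};
    // Writing element (n, h, w) at offset n*H*W + w*H + h yields an output buffer
    // that is contiguous when read as [N, W, H].
    const std::array<ck::index_t, 3> out_strides = {H * W, 1, H};

    return op.MakeArgumentPointer(in_lengths, in_strides, out_lengths, out_strides,
                                  p_in, p_out,
                                  ck::tensor_operation::element_wise::PassThrough{});
}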
+ +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/utility/reduction_enums.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// output[indices] = input +template +struct DevicePutElement : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_input, + const void* p_indices, + void* p_output, + index_t input_length, + index_t output_length, + ElementwiseOperation elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_reduce.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_reduce.hpp new file mode 100755 index 000000000000..c2721b18455c --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_reduce.hpp @@ -0,0 +1,71 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceReduce : public BaseOperator +{ + static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 1 : Rank - NumReduceDim; + + virtual std::unique_ptr + MakeArgumentPointer(const std::array inLengths, + const std::array inStrides, + const std::array outLengths, + const std::array outStrides, + const std::array reduceDims, + double alpha, + double beta, + const void* in_dev, + const void* in_index_dev, + void* out_dev, + void* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceReducePtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp new file mode 100755 index 000000000000..67286e5541e8 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceReduceMultiD : public BaseOperator +{ + static constexpr index_t NumOutDim = (Rank - NumReduceDim == 0) ? 
1 : Rank - NumReduceDim; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const std::array inLengths, + const std::array inStrides, + const std::array, NumDTensor> DsLengths, + const std::array, NumDTensor> DsStrides, + const std::array outLengths, + const std::array outStrides, + const std::array reduceDims, + const void* in_dev, + const std::array ds_dev, + void* out_dev, + const InElementwiseOperation in_elementwise_op, + const OutElementwiseOperation out_elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceReduceMultiDPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_softmax.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_softmax.hpp new file mode 100755 index 000000000000..1902fd09eec1 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_softmax.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceSoftmax : public BaseOperator +{ + // + // @brief Makes a pointer to Argument class. + // + // @param[in] inLengths Input tensor extent(s) from high to low dimension + // @param[in] inStrides Input tensor stride(s) from high to low dimension + // @param[in] reduceDims The dimension(s) the normalization operation is applied + // @param[in] alpha double type value + // @param[in] beta double type value + // @param[in] in_dev Typeless const pointer in device memory storing the input + // tensor + // @param out_dev Typeless pointer in device memory storing the output tensor + // @param[in] in_elementwise_op The input elementwise operation. + // @param[in] acc_elementwise_op The accumulation elementwise operation. + // + // @return Unique pointer to the Argument class. + // + virtual std::unique_ptr + MakeArgumentPointer(const std::vector inLengths, + const std::vector inStrides, + const std::vector reduceDims, + double alpha, + double beta, + const void* in_dev, + void* out_dev, + InElementwiseOp in_elementwise_op, + AccElementwiseOp acc_elementwise_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +template +using DeviceSoftmaxPtr = std::unique_ptr>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp new file mode 100755 index 000000000000..eeccd977ccbf --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) 
+// Assume: +// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] +// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] +// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +template +struct DeviceSplitKContractionMultipleD : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t split_k) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp new file mode 100755 index 000000000000..8824f44ec5f1 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/gemm_specialization.hpp @@ -0,0 +1,59 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct GemmSpecialization +{ + // Gemm + Default, + MPadding, + NPadding, + KPadding, + MNPadding, + MKPadding, + NKPadding, + MNKPadding, + // Gemm + Gemm + OPadding, + MOPadding, + NOPadding, + KOPadding, + MNOPadding, + MKOPadding, + NKOPadding, + MNKOPadding, +}; +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +inline std::string getGemmSpecializationString(const GemmSpecialization& s) +{ + switch(s) + { + case GemmSpecialization::Default: return "Default"; + case GemmSpecialization::MPadding: return "MPadding"; + case GemmSpecialization::NPadding: return "NPadding"; + case GemmSpecialization::KPadding: return "KPadding"; + case GemmSpecialization::MNPadding: return "MNPadding"; + case GemmSpecialization::MKPadding: return "MKPadding"; + case GemmSpecialization::NKPadding: return "NKPadding"; + case GemmSpecialization::MNKPadding: return "MNKPadding"; + case GemmSpecialization::OPadding: return "OPadding"; + case GemmSpecialization::MOPadding: return "MOPadding"; + case GemmSpecialization::NOPadding: return "NOPadding"; + case GemmSpecialization::KOPadding: return "KOPadding"; + case GemmSpecialization::MNOPadding: return "MNOPadding"; + case GemmSpecialization::MKOPadding: return "MKOPadding"; + case GemmSpecialization::NKOPadding: return "NKOPadding"; + case GemmSpecialization::MNKOPadding: return "MNKOPadding"; + default: return "Unrecognized specialization!"; + } +} +#endif + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/helper.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/helper.hpp new file mode 100755 index 000000000000..c0e5ce9dd3ca --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/helper.hpp @@ -0,0 +1,478 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 
2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include +#include + +// functions to return the corresponding structs based on generated template parameters + +using layouts = std::variant; +// return the layout type: currently this is the only type supported in MIOpen +auto layout_type(std::string type) +{ + if(type == "ck::tensor_layout::convolution::NHWGK") + { + return ck::tensor_layout::convolution::NHWGK{}; + } + throw std::runtime_error("Incorrect layout"); +} +// return the right gemm spec based on the generated template parameters +ck::tensor_operation::device::GemmSpecialization gemm_type(std::string type) +{ + if(type == "ck::tensor_operation::device::GemmSpecialization::Default") + { + return ck::tensor_operation::device::GemmSpecialization::Default; + } + if(type == "ck::tensor_operation::device::GemmSpecialization::MNKPadding") + { + return ck::tensor_operation::device::GemmSpecialization::MNKPadding; + } + throw std::runtime_error("Incorrect gemm spec: " + type); +} + +// return the type of convolution +ck::tensor_operation::device::ConvolutionForwardSpecialization conv_type(std::string type) +{ + if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Default") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + } + if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0; + } + if(type == + "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0; + } + if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC") + { + return ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC; + } + throw std::runtime_error("Incorrect conv spec: " + type); +} + +// Function to call on MatrixPadder via a wrapper struct +// NOTE: CK only uses MNKPadding for forward convolution +template +auto pad(ck::index_t mpb, + ck::index_t npb, + ck::index_t kpb, + ck::tensor_operation::device::GemmSpecialization gemm, + CDesc_MRaw_NRaw conv) +{ + if(gemm == ck::tensor_operation::device::GemmSpecialization::MNKPadding) + { + ck::tensor_operation::device::MatrixPadder< + ck::tensor_operation::device::GemmSpecialization::MNKPadding, + ck::index_t, + ck::index_t, + ck::index_t> + a; + a.MPerTile_ = mpb; + a.NPerTile_ = npb; + a.KPerTile_ = kpb; + auto tmp = grid_desc(a, conv); + return tmp; + } + throw std::runtime_error("Incorrect template parameters, check gemm spec"); +} + +// Functions to call on TransformConvFwdToGemm through wrapper: different functions based on num +// dims +// FIXME: add a way to properly pass in the layout +auto transform_conv(ck::index_t num_dim, + ck::tensor_operation::device::ConvolutionForwardSpecialization spec, + ck::Array out_lengths, + ck::Array out_strides) +{ + ck::Array dummy_dims; + ck::Array dummy_spatial_dims; + if(num_dim == 2 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + 
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + if(num_dim == 2 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + if(num_dim == 2 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) + { + ck::tensor_operation::TransformConvFwdToGemm< + 2, + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + throw std::runtime_error("Incorrect conv spec"); +} + +auto transform_conv_3d(ck::index_t num_dim, + ck::tensor_operation::device::ConvolutionForwardSpecialization spec, + ck::Array out_lengths, + ck::Array out_strides) +{ + ck::Array dummy_dims; + ck::Array dummy_spatial_dims; + + if(num_dim == 3 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + if(num_dim == 3 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + if(num_dim == 3 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> + conv_fwd{dummy_dims, + 
dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) + { + ck::tensor_operation::TransformConvFwdToGemm< + 3, + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + throw std::runtime_error("Incorrect conv spec"); +} + +auto transform_conv_1d(ck::index_t num_dim, + ck::tensor_operation::device::ConvolutionForwardSpecialization spec, + ck::Array out_lengths, + ck::Array out_strides) +{ + ck::Array dummy_dims; + ck::Array dummy_spatial_dims; + + if(num_dim == 1 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + if(num_dim == 1 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + if(num_dim == 1 && + spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC) + { + ck::tensor_operation::TransformConvFwdToGemm< + 1, + ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC> + conv_fwd{dummy_dims, + dummy_dims, + dummy_dims, + dummy_dims, + out_lengths, + out_strides, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims, + dummy_spatial_dims}; + + auto res = ck::tensor_operation::TransformConv(); + return res.transform_func(conv_fwd); + } + throw std::runtime_error("Incorrect dims or conv spec"); +} + +template +auto block_2_etile(ck::index_t m_per_block, ck::index_t n_per_block, CGridDesc_M_N matrix_padder) +{ + if(m_per_block == 32 && n_per_block == 64) + { + auto b2e = ck::BlockToCTileMap_M00_N0_M01Adapt<32, 64, CGridDesc_M_N>(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 32 && n_per_block == 
128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<32, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 64 && n_per_block == 32) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<64, 32, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 64 && n_per_block == 64) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<64, 64, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 64 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<64, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 32) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 32, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 64) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 64, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 128 && n_per_block == 256) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<128, 256, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + if(m_per_block == 256 && n_per_block == 128) + { + ck::BlockToCTileMap_M00_N0_M01Adapt<256, 128, CGridDesc_M_N> b2e(matrix_padder); + return b2e.CalculateGridSize(matrix_padder); + } + throw std::runtime_error("Incorrect template parameters"); +} + +// wrapper functions by dims to get grid size - uses above 3 functions +// TODO: eventually remove the 1d/2d versions as CK will only support 3d convolutions +auto get_launch_params_1d(ck::host::Solution solution, + ck::Array out_lengths, + ck::Array out_strides) +{ + auto num_dim = solution.GetTemplateParameter("NumDim"); + auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + auto k_per_block = solution.GetTemplateParameter("KPerBlock"); + auto GemmType = solution.GetTemplateParameter("GemmSpecialization"); + auto ConvType = solution.GetTemplateParameter("ConvSpecialization"); + ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType); + ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType); + auto conv_to_gemm_transformer = transform_conv_1d(num_dim, ConvSpec, out_lengths, out_strides); + auto matrix_padder = + pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer); + auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder); + return b2e; +} + +auto get_launch_params(ck::host::Solution solution, + ck::Array out_lengths, + ck::Array out_strides) +{ + auto num_dim = solution.GetTemplateParameter("NumDim"); + auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + auto k_per_block = solution.GetTemplateParameter("KPerBlock"); + auto GemmType = solution.GetTemplateParameter("GemmSpecialization"); + auto ConvType = solution.GetTemplateParameter("ConvSpecialization"); + ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType); + ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType); + auto conv_to_gemm_transformer = transform_conv(num_dim, 
ConvSpec, out_lengths, out_strides); + auto matrix_padder = + pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer); + auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder); + return b2e; +} + +auto get_launch_params_3d(ck::host::Solution solution, + ck::Array out_lengths, + ck::Array out_strides) +{ + auto num_dim = solution.GetTemplateParameter("NumDim"); + auto m_per_block = solution.GetTemplateParameter("MPerBlock"); + auto n_per_block = solution.GetTemplateParameter("NPerBlock"); + auto k_per_block = solution.GetTemplateParameter("KPerBlock"); + auto GemmType = solution.GetTemplateParameter("GemmSpecialization"); + auto ConvType = solution.GetTemplateParameter("ConvSpecialization"); + ck::tensor_operation::device::GemmSpecialization GemmSpec = gemm_type(GemmType); + ck::tensor_operation::device::ConvolutionForwardSpecialization ConvSpec = conv_type(ConvType); + auto conv_to_gemm_transformer = transform_conv_3d(num_dim, ConvSpec, out_lengths, out_strides); + auto matrix_padder = + pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer); + auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder); + return b2e; +} diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp new file mode 100755 index 000000000000..72c011bfb25a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,1040 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#ifndef CK_CODE_GEN_RTC +#include +#include +#include +#include +#include +#include + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#endif + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. 
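+ *
+ * Illustrative sketch (an assumption added for readability, not taken from the original
+ * comment): for the evenly strided case the offset getters effectively reduce to a
+ * multiply by a per-batch stride, e.g. for group index g
+ *   GetAPtrOffset(g) ~= g * BatchStrideA_   (set from a_g_n_c_wis_strides[0] below),
+ *   GetBPtrOffset(g) ~= g * BatchStrideB_   (set from b_g_k_c_xs_strides[0]),
+ *   GetEPtrOffset(g) ~= g * BatchStrideE_   (set from e_g_n_k_wos_strides[0]).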
+ * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + if constexpr(isMultiA || isMultiB) + { + AsPointer p_as_grid_grp; + BsPointer p_bs_grid_grp; + + const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx); + + static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); + static_for<0, NumATensor, 1>{}( + [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; }); + + const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx); + + static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); + static_for<0, NumBTensor, 1>{}( + [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; }); + + GridwiseGemm::template Run( + p_as_grid_grp, + p_bs_grid_grp, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + 
ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } + else + { + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + + GridwiseGemm::template Run( + p_as_grid + a_batch_offset, + p_bs_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } +#else + ignore = p_as_grid; + ignore = p_bs_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ + + device_grouped_conv_fwd_multiple_abd_xdl_cshuffle< + GridwiseGemm, + AsPointer, // tuples if multi AB, pointers if no + BsPointer, + DsPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + Block2ETileMap, + ComputePtrOffsetOfBatch, + HasMainKBlockLoop, + isMultiA, + isMultiB>(p_as_grid, + p_bs_grid, + p_ds_grid, + *p_e_grid, + a_element_op, + b_element_op, + cde_element_op, + batch_count, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map, + compute_ptr_offset_of_batch); +} + +} // namespace + +#ifdef CK_CODE_GEN_RTC +template +using is_tuple = decltype(ck::declval().IsTuple()); +#else +template +using is_tuple = decltype(std::declval().IsTuple()); +#endif + +// +// @brief Device Convolution operation. 
+// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template ::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + LoopScheduler LoopSched = make_default_loop_scheduler()> +struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle + : public DeviceGroupedConvFwdMultipleABD +{ + using DeviceOp = CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; + + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; + + static constexpr index_t NumATensor = GetNumABTensors(); + static constexpr index_t NumBTensor = GetNumABTensors(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + __host__ __device__ static auto + MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + return in_gemmm_gemmk_desc; + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + return wei_gemmn_gemmk_desc; + } + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + // Shape of Ds and E must be aligned. Strides can be different. + // Pass e_g_n_k_wos_lengths for logical broadcast. 
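+    // Worked example (illustrative, assuming a 2D grouped forward convolution): the
+    // transformer above maps the convolution to an implicit GEMM with
+    //   GemmM = N * Ho * Wo,   GemmN = K,   GemmK = C * Y * X   (per group),
+    // so e.g. N = 2, Ho = Wo = 28, K = 64, C = 32, Y = X = 3 gives M = 1568, N = 64,
+    // K = 288; each D tensor reuses the E (M, N) descriptor shape with its own strides.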
+ static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer); + }, + Number{}); + } + + // desc for problem definition + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; + using AGridDesc_M_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using BGridDesc_N_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using DsGridDesc_M_N = + remove_cvref_t; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + + // If we are using multiAB and one of the template datatype parameters is not a tuple, convert + // it to it + using GemmADataType = ck::conditional_t, ADataType>; + using GemmBDataType = ck::conditional_t, BDataType>; + +#define GridwiseGemmMultiABDTemplateParameters \ + GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ + KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched + +#define GridwiseGemmTemplateParameters \ + GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ + NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched + // Use appropriate gridwise gemm + using GridwiseGemm = ck::conditional_t< + isMultiA || isMultiB, + GridwiseGemmMultipleABD_xdl_cshuffle, + GridwiseGemmMultipleD_xdl_cshuffle>; + + // If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers. 
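+    // For example (illustrative only): with ADataType = ck::Tuple<bf16_t, bf16_t> the
+    // caller is expected to pass a ck::Array of two const void* for p_as, whereas a
+    // plain ADataType = bf16_t keeps the single const void* overload selected below.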
+ using APointers = ck::conditional_t&, const void*>; + using BPointers = ck::conditional_t&, const void*>; + // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not + // in initializer list what is required for single const pointer). + using AGridPointer = remove_cvref_t< + decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm, ADataType > ())>; + using BGridPointer = remove_cvref_t< + decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm, BDataType > ())>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument + { + __device__ __host__ Argument( + APointers p_as, + BPointers p_bs, + const ck::Array& p_ds, + void* p_e, + const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides, + const ck::Array, NumDTensor>& ds_g_n_k_wos_lengths, + const ck::Array, NumDTensor>& ds_g_n_k_wos_strides, + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_as_grid_{}, + p_bs_grid_{}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_c_wis_lengths[0]}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, + a_grid_desc_m_k_{ + DeviceOp::MakeAGridDescriptor_M_K(conv_to_gemm_transformer_)}, + b_grid_desc_n_k_{ + DeviceOp::MakeBGridDescriptor_N_K(conv_to_gemm_transformer_)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + compute_ptr_offset_of_batch_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // A/B/E Batch Stride + if constexpr(isMultiA || isMultiB) + { + static_for<0, NumATensor, 1>{}([&](auto i) { + 
// Init compute_ptr_offset_of_batch_ for multiple AB + compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0]; + + // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data + // type is not tuple) + using DataType = remove_cvref_t>; + // It is possible that one of the AB is a pointer and one is a tuple. + // Then also use multiAB but we have to cast single pointer instead of tuple of + // pointer. + if constexpr(isMultiA) + { + // p_as is tuple + p_as_grid_(i) = static_cast(p_as[i.value]); + } + else + { + // if MultiB and not MultiA then p_as is single pointer + p_as_grid_(i) = static_cast(p_as); + } + }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + // Init compute_ptr_offset_of_batch_ for multiple AB + compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0]; + + using DataType = remove_cvref_t>; + // It is possible that one of the AB is a pointer and one is a tuple. + // Then also use multiAB but we have to cast single pointer instead of tuple of + // pointer. + if constexpr(isMultiB) + { + // p_bs is tuple + p_bs_grid_(i) = static_cast(p_bs[i.value]); + } + else + { + // if MultiA and not MultiB then p_bs is single pointer + p_bs_grid_(i) = static_cast(p_bs); + } + }); + } + else + { + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + + // p_as and p_bs are pointers + p_as_grid_(I0) = static_cast(p_as); + p_bs_grid_(I0) = static_cast(p_bs); + } + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); + + // D batch stride + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + ds_g_n_k_wos_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); + }); + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + // populate desc for Ds/E + if constexpr(isMultiA || isMultiB) + { + const auto as_grid_desc_ak0_m_ak1 = + generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = + generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number{}); + + if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + else + { + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + } + + // private: + // pointers (tuple if multi AB, pointer 
if no) + AGridPointer p_as_grid_; + BGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + index_t num_group_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_; + + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch + compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + ck::Array a_g_n_c_wis_lengths_; + ck::Array a_g_n_c_wis_strides_; + ck::Array b_g_k_c_xs_lengths_; + ck::Array b_g_k_c_xs_strides_; + ck::Array, NumDTensor> ds_g_n_k_wos_lengths_; + ck::Array, NumDTensor> ds_g_n_k_wos_strides_; + ck::Array e_g_n_k_wos_lengths_; + ck::Array e_g_n_k_wos_strides_; + ck::Array conv_filter_strides_; + ck::Array conv_filter_dilations_; + ck::Array input_left_pads_; + ck::Array input_right_pads_; + }; + + static __device__ __host__ bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t C = arg.a_g_n_c_wis_lengths_[2]; + + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + const index_t C = arg.b_g_k_c_xs_lengths_[2]; + + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of Ds + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + // FIXME: 
layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + { + const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; + + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + valid = false; + } + + if constexpr(is_same_v) + { + // G and K must be the same + if(arg.ds_g_n_k_wos_lengths_[i][0] != arg.e_g_n_k_wos_lengths_[0] || + arg.ds_g_n_k_wos_lengths_[i][2] != arg.e_g_n_k_wos_lengths_[2]) + { + valid = false; + } + } + else + { + // E and D must have the same shape + for(index_t d = 0; d < NDimSpatial + 3; d++) + { + if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d]) + { + valid = false; + } + } + } + } + else + { + valid = false; + } + }); + + if(!valid) + { + return false; + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.e_g_n_k_wos_lengths_[2]; + + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // check Gridwise GEMM + if constexpr(isMultiA || isMultiB) + { + // Genarate tuples with the same descriptors + const auto as_grid_desc_ak0_m_ak1 = + generate_tuple([&](auto) { return arg.a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = + generate_tuple([&](auto) { return arg.b_grid_desc_n_k_; }, Number{}); + return GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + + static __device__ __host__ auto MakeArgument( + APointers p_as, + BPointers p_bs, + const ck::Array& p_ds, + void* p_e, + const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides, + const ck::Array, NumDTensor>& ds_g_n_k_wos_lengths, + const ck::Array, NumDTensor>& ds_g_n_k_wos_strides, + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& input_left_pads, + const ck::Array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static __device__ __host__ auto MakeArgument( + APointers p_as, + BPointers p_bs, + const ck::Array& p_ds, + void* p_e, + const ck::Array& a_g_n_c_wis_lengths, + const ck::Array& a_g_n_c_wis_strides, + const ck::Array& b_g_k_c_xs_lengths, + const ck::Array& b_g_k_c_xs_strides, + const ck::Array, NumDTensor>& ds_g_n_k_wos_lengths, + const ck::Array, NumDTensor>& ds_g_n_k_wos_strides, + const ck::Array& e_g_n_k_wos_lengths, + const ck::Array& e_g_n_k_wos_strides, + const ck::Array& conv_filter_strides, + const ck::Array& conv_filter_dilations, + const ck::Array& 
input_left_pads, + const ck::Array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + ck::Array a_g_n_c_wis_lengths_i32; + ck::Array a_g_n_c_wis_strides_i32; + ck::Array b_g_k_c_xs_lengths_i32; + ck::Array b_g_k_c_xs_strides_i32; + ck::Array, NumDTensor> ds_g_n_k_wos_lengths_i32; + ck::Array, NumDTensor> ds_g_n_k_wos_strides_i32; + ck::Array e_g_n_k_wos_lengths_i32; + ck::Array e_g_n_k_wos_strides_i32; + ck::Array conv_filter_strides_i32; + ck::Array conv_filter_dilations_i32; + ck::Array input_left_pads_i32; + ck::Array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op}; + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_avgpool2d_bwd_nhwc_nhwc.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_avgpool2d_bwd_nhwc_nhwc.hpp new file mode 100755 index 000000000000..7fca3e29886e --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_avgpool2d_bwd_nhwc_nhwc.hpp @@ -0,0 +1,523 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
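+// Illustrative note (not part of the original header): backward average pooling
+// scatters each output gradient evenly over its window, i.e. for a Y x X window
+//   din[n, c, hi, wi] = sum over windows covering (hi, wi) of dout[n, c, ho, wo] / (Y * X),
+// which this header realizes as a set of threadwise reductions, one per
+// (i_ytilde, i_xtilde) phase of the stride/dilation decomposition.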
+ +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// In and Din = [N, C, Hi, Wi] +// Out and Dout = [N, C, Ho, Wo] +// Out = AvgPool2dFwd(In) +// Din = AvgPool2dBwd(Dout) +// Pooling dimension = H, W +template +struct DeviceAvgPool2dBwd_NHWC_NHWC : public DeviceAvgPoolBwd<2, + DOutDataType, + DInDataType, + tensor_layout::convolution::NHWC, + tensor_layout::convolution::NHWC> +{ + + static constexpr ck::index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto + Make2DGridDescriptor_Out_M_K_In_M(const std::vector& dout_n_c_wos_lengths, + const std::vector& din_n_c_wos_length, + const std::vector& dout_n_c_wos_strides, + const std::vector& din_n_c_wos_strides, + const std::vector& window_lengths, + const std::vector& window_strides, + const std::vector& window_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads, + const std::vector& tildes) + { + index_t i_ytilde = tildes[0]; + index_t i_xtilde = tildes[1]; + + const index_t N = dout_n_c_wos_lengths[0]; + const index_t C = dout_n_c_wos_lengths[1]; + const index_t Ho = dout_n_c_wos_lengths[2]; + const index_t Wo = dout_n_c_wos_lengths[3]; + + const index_t Hi = din_n_c_wos_length[2]; + const index_t Wi = din_n_c_wos_length[3]; + + const index_t Y = window_lengths[0]; + const index_t X = window_lengths[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t ConvStrideH = window_strides[0]; + const index_t ConvStrideW = window_strides[1]; + + const index_t ConvDilationH = window_dilations[0]; + const index_t ConvDilationW = window_dilations[1]; + + const index_t Ni_stride = dout_n_c_wos_strides[0]; + const index_t Ci_stride = dout_n_c_wos_strides[1]; + const index_t Ho_stride = dout_n_c_wos_strides[2]; + const index_t Wo_stride = dout_n_c_wos_strides[3]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on Tildes that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = 
math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IHTildeSliceEnd = + math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = + math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // ReduceK is different for each Reduce + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // Problem size of reduction kernel + const index_t MRaw = N * HTildeSlice * WTildeSlice * C; + const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw; + + const index_t KRaw = YDotSlice * XDotSlice; + const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw; + + const auto out_n_ho_wo_c_grid_desc = make_naive_tensor_descriptor( + make_tuple(N, Ho, Wo, C), make_tuple(Ni_stride, Ho_stride, Wo_stride, Ci_stride)); + + // Out[ReduceM, ReduceK] + const auto out_n_hop_wop_c_grid_desc = transform_tensor_descriptor( + out_n_ho_wo_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_c_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto out_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice, C)), + make_merge_transform(make_tuple(YDotSlice, XDotSlice))), + make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_grid_desc_reducem_reducek = transform_tensor_descriptor( + out_grid_desc_reducemraw_reducekraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // In[ReduceM] + const auto in_n_hi_wi_c_grid_desc = + 
make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C), + make_tuple(din_n_c_wos_strides[0], + din_n_c_wos_strides[2], + din_n_c_wos_strides[3], + din_n_c_wos_strides[1])); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_grid_desc_reducemraw = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice, C))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto in_grid_desc_reducem = + transform_tensor_descriptor(in_grid_desc_reducemraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return make_tuple(out_grid_desc_reducem_reducek, in_grid_desc_reducem); + } + + using DoutDinGridDesc = decltype(Make2DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0}, + {0, 0, 0, 0}, + {0, 0, 0, 0}, + {0, 0, 0, 0}, + {0, 0}, + {0, 0}, + {0, 0}, + {0, 0}, + {0, 0}, + {0, 0})); + + using DoutGridDesc_M_K = remove_cvref_t>; + using DinGridDesc_M = remove_cvref_t>; + + // FIXME + // for NHWC, the dim C is the fastest dimension, and is not reduced. + // Hence, it is in M dimension for reduction kernel. 
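+    // Worked example (illustrative): a 2x2 window with stride 2 and dilation 1 gives
+    // GcdStrideDilation = 1, hence YTilde = XTilde = 2 and four (i_ytilde, i_xtilde)
+    // phases; each phase reduces K = YDotSlice * XDotSlice = 1 element per destination
+    // into M = N * HTildeSlice * WTildeSlice * C positions of Din.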
+ static constexpr index_t OutSrcInDstVectorDim = 0; // 0: M, 1: K + + using PassThrough = tensor_operation::element_wise::PassThrough; + using Div = tensor_operation::element_wise::UnaryDivide; + + using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise; + + struct Argument : public BaseArgument + { + Argument(const DOutDataType* p_dout, + DInDataType* p_din, + std::vector dout_n_c_wos_lengths, + std::vector din_n_c_wos_length, + std::vector dout_n_c_wos_strides, + std::vector din_n_c_wos_strides, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + : p_dout_grid_{p_dout}, + p_din_grid_{p_din}, + dout_n_c_wos_lengths_{dout_n_c_wos_lengths}, + din_n_c_wos_length_{din_n_c_wos_length}, + dout_n_c_wos_strides_{dout_n_c_wos_strides}, + din_n_c_wos_strides_{din_n_c_wos_strides}, + num_reduce_{1}, + div_element_op_{window_lengths[0] * window_lengths[1]} + { + std::vector Tildes(NDimSpatial); + for(int i = 0; i < NDimSpatial; ++i) + { + int GcdStrideDilation = math::gcd(window_strides[i], window_dilations[i]); + Tildes[i] = window_strides[i] / GcdStrideDilation; + num_reduce_ *= Tildes[i]; + } + + for(index_t i_ytilde = 0; i_ytilde < Tildes[0]; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < Tildes[1]; ++i_xtilde) + { + const auto YDotSlice = + math::integer_divide_ceil(window_lengths[0] - i_ytilde, Tildes[0]); + const auto XDotSlice = + math::integer_divide_ceil(window_lengths[1] - i_xtilde, Tildes[1]); + + if(YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto dout_din_grid_desc = + Make2DGridDescriptor_Out_M_K_In_M(dout_n_c_wos_lengths, + din_n_c_wos_length, + dout_n_c_wos_strides, + din_n_c_wos_strides, + window_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + {i_ytilde, i_xtilde}); + + dout_grid_desc_m_k_container_.push_back(dout_din_grid_desc[I0]); + din_grid_desc_m_container_.push_back(dout_din_grid_desc[I1]); + } + } + } + + const DOutDataType* p_dout_grid_; + DInDataType* p_din_grid_; + std::vector dout_n_c_wos_lengths_; + std::vector din_n_c_wos_length_; + std::vector dout_n_c_wos_strides_; + std::vector din_n_c_wos_strides_; + + int num_reduce_; + std::vector dout_grid_desc_m_k_container_; + std::vector din_grid_desc_m_container_; + + Div div_element_op_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + for(index_t i = 0; i < arg.num_reduce_; i++) + { + const auto kernel = kernel_reduce_threadwise; + + ck::index_t M = arg.dout_grid_desc_m_k_container_[i].GetLength(I0); + const index_t grid_size = (M / M_BlockTileSize); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.dout_grid_desc_m_k_container_[i], + arg.din_grid_desc_m_container_[i], + PassThrough{}, + arg.div_element_op_, + float(1), + arg.p_dout_grid_, + nullptr, + float(0), + arg.p_din_grid_, + nullptr); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + constexpr index_t Rank = NDimSpatial + 2; + int doutFastestDim = -1; + int dinFastestDim = -1; + + for(int i = 0; i < Rank; ++i) + { + if(arg.dout_n_c_wos_strides_[i] == 1) + doutFastestDim = i; + if(arg.din_n_c_wos_strides_[i] == 1) + 
dinFastestDim = i; + } + if(InSrcOutDstVectorSize != 1 && (dinFastestDim != 1 || doutFastestDim != 1)) + { + return false; + } + if(doutFastestDim == -1 || dinFastestDim == -1) + { + if constexpr(InSrcOutDstVectorSize != 1) + return false; + } + else + { + if(arg.dout_n_c_wos_lengths_[doutFastestDim] % InSrcOutDstVectorSize != 0) + return false; + if(arg.din_n_c_wos_length_[dinFastestDim] % InSrcOutDstVectorSize != 0) + return false; + } + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + std::unique_ptr + MakeArgumentPointer(const void* p_dout, + void* p_din, + std::vector dout_n_c_wos_lengths, + std::vector din_n_c_wos_length, + std::vector dout_n_c_wos_strides, + std::vector din_n_c_wos_strides, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) override + { + constexpr index_t Rank = NDimSpatial + 2; + + if(dout_n_c_wos_strides.size() != Rank || din_n_c_wos_strides.size() != Rank || + dout_n_c_wos_lengths.size() != Rank || din_n_c_wos_length.size() != Rank) + { + throw std::runtime_error("dimension of [dout|din]_n_c_wos_strides or " + "[dout|din]_n_c_wos_lengths is not equal to Rank"); + } + + if(window_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial || + window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial || + input_right_pads.size() != NDimSpatial) + { + throw std::runtime_error( + "dimension of [window_lengths, window_strides, window_dilations, input_left_pads, " + "input_right_pads] is not equal to Rank"); + } + return std::make_unique(static_cast(p_dout), + static_cast(p_din), + dout_n_c_wos_lengths, + din_n_c_wos_length, + dout_n_c_wos_strides, + din_n_c_wos_strides, + window_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceAvgPool2dBwd<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp new file mode 100755 index 000000000000..3a2280a75c27 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp @@ -0,0 +1,575 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
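+// Illustrative note (not part of the original header): this 3D variant mirrors the 2D
+// one above with an extra depth dimension, so the backward pass is decomposed into
+// ZTilde * YTilde * XTilde threadwise reductions; e.g. a 2x2x2 window with stride 2 and
+// dilation 1 in every dimension yields 8 reduction phases over Dout into Din.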
+ +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// In and Din = [N, C, Di, Hi, Wi] +// Out and Dout = [N, C, Do, Ho, Wo] +// Out = AvgPoolFwd(In) +// Din = AvgPoolBwd(Dout) +// Pooling dimension = D, H, W +template +struct DeviceAvgPool3dBwd_NDHWC_NDHWC : public DeviceAvgPoolBwd<3, + DOutDataType, + DInDataType, + tensor_layout::convolution::NDHWC, + tensor_layout::convolution::NDHWC> +{ + static constexpr ck::index_t NDimSpatial = 3; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto + Make3DGridDescriptor_Out_M_K_In_M(const std::vector& dout_n_c_wos_lengths, + const std::vector& din_n_c_wos_length, + const std::vector& dout_n_c_wos_strides, + const std::vector& din_n_c_wos_strides, + const std::vector& window_lengths, + const std::vector& window_strides, + const std::vector& window_dilations, + const std::vector& input_left_pads, + const std::vector& input_right_pads, + const std::vector& tildes) + { + index_t i_ztilde = tildes[0]; + index_t i_ytilde = tildes[1]; + index_t i_xtilde = tildes[2]; + + const index_t N = dout_n_c_wos_lengths[0]; + const index_t C = dout_n_c_wos_lengths[1]; + + const index_t Di = din_n_c_wos_length[2]; + const index_t Hi = din_n_c_wos_length[3]; + const index_t Wi = din_n_c_wos_length[4]; + + const index_t Do = dout_n_c_wos_lengths[2]; + const index_t Ho = dout_n_c_wos_lengths[3]; + const index_t Wo = dout_n_c_wos_lengths[4]; + + const index_t Z = window_lengths[0]; + const index_t Y = window_lengths[1]; + const index_t X = window_lengths[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t ConvStrideD = window_strides[0]; + const index_t ConvStrideH = window_strides[1]; + const index_t ConvStrideW = window_strides[2]; + + const index_t ConvDilationD = window_dilations[0]; + const index_t ConvDilationH = window_dilations[1]; + const index_t ConvDilationW = window_dilations[2]; + + const auto out_n_do_ho_wo_c_grid_desc = + make_naive_tensor_descriptor(make_tuple(N, Do, Ho, Wo, C), + make_tuple(dout_n_c_wos_strides[0], + dout_n_c_wos_strides[2], + dout_n_c_wos_strides[3], + dout_n_c_wos_strides[4], + dout_n_c_wos_strides[1])); + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = ConvStrideD / GcdStrideDilationD; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto ZDot 
= math::integer_divide_ceil(Z, ZTilde); + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto DTilde = Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD); + const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on Tildes that contribute to non-padding area of input tensor + const auto IDTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IDTildeSliceEnd = + math::min(DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); + const auto IHTildeSliceEnd = + math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = + math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // ReduceK is different for each Reduce + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // Problem size of reduction kernel + const index_t MRaw = N * DTildeSlice * HTildeSlice * WTildeSlice * C; + const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw; + + const index_t KRaw = ZDotSlice * YDotSlice * XDotSlice; + const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw; + + // Out[ReduceM, ReduceK] + const auto out_n_dop_hop_wop_c_grid_desc = transform_tensor_descriptor( + out_n_do_ho_wo_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Do, I0, I0), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc = + transform_tensor_descriptor( + out_n_dop_hop_wop_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(ZDot, DTilde), + make_tuple(-ConvDilationD / GcdStrideDilationD, I1)), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + 
make_slice_transform(ZDot, I0, ZDotSlice),
+                           make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
+                           make_slice_transform(YDot, I0, YDotSlice),
+                           make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
+                           make_slice_transform(XDot, I0, XDotSlice),
+                           make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
+                           make_pass_through_transform(C)),
+                make_tuple(Sequence<0>{},
+                           Sequence<1>{},
+                           Sequence<2>{},
+                           Sequence<3>{},
+                           Sequence<4>{},
+                           Sequence<5>{},
+                           Sequence<6>{},
+                           Sequence<7>{}),
+                make_tuple(Sequence<0>{},
+                           Sequence<1>{},
+                           Sequence<2>{},
+                           Sequence<3>{},
+                           Sequence<4>{},
+                           Sequence<5>{},
+                           Sequence<6>{},
+                           Sequence<7>{}));
+
+        const auto out_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor(
+            out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc,
+            make_tuple(
+                make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C)),
+                make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice))),
+            make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+        const auto out_grid_desc_reducem_reducek = transform_tensor_descriptor(
+            out_grid_desc_reducemraw_reducekraw,
+            make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+        // In[ReduceM]
+        const auto in_n_di_hi_wi_c_grid_desc =
+            make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C),
+                                         make_tuple(din_n_c_wos_strides[0],
+                                                    din_n_c_wos_strides[2],
+                                                    din_n_c_wos_strides[3],
+                                                    din_n_c_wos_strides[4],
+                                                    din_n_c_wos_strides[1]));
+
+        const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
+            in_n_di_hi_wi_c_grid_desc,
+            make_tuple(make_pass_through_transform(N),
+                       make_pad_transform(Di, InLeftPadD, InRightPadD),
+                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
+                       make_pad_transform(Wi, InLeftPadW, InRightPadW),
+                       make_pass_through_transform(C)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
+
+        const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
+            transform_tensor_descriptor(
+                in_n_dip_hip_wip_c_grid_desc,
+                make_tuple(make_pass_through_transform(N),
+                           make_embed_transform(make_tuple(ZTilde, DTilde),
+                                                make_tuple(ConvDilationD, ConvStrideD)),
+                           make_embed_transform(make_tuple(YTilde, HTilde),
+                                                make_tuple(ConvDilationH, ConvStrideH)),
+                           make_embed_transform(make_tuple(XTilde, WTilde),
+                                                make_tuple(ConvDilationW, ConvStrideW)),
+                           make_pass_through_transform(C)),
+                make_tuple(
+                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
+                make_tuple(Sequence<0>{},
+                           Sequence<1, 2>{},
+                           Sequence<3, 4>{},
+                           Sequence<5, 6>{},
+                           Sequence<7>{}));
+
+        const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
+            transform_tensor_descriptor(
+                in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
+                make_tuple(make_pass_through_transform(N),
+                           make_freeze_transform(i_ztilde),
+                           make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
+                           make_freeze_transform(i_ytilde),
+                           make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
+                           make_freeze_transform(i_xtilde),
+                           make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
+                           make_pass_through_transform(C)),
+                make_tuple(Sequence<0>{},
+                           Sequence<1>{},
+                           Sequence<2>{},
+                           Sequence<3>{},
+                           Sequence<4>{},
+                           Sequence<5>{},
+                           Sequence<6>{},
Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<3>{}, + Sequence<4>{})); + + const auto in_grid_desc_reducemraw = transform_tensor_descriptor( + in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto in_grid_desc_reducem = + transform_tensor_descriptor(in_grid_desc_reducemraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return make_tuple(out_grid_desc_reducem_reducek, in_grid_desc_reducem); + } + + using DoutDinGridDesc = decltype(Make3DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0}, + {0, 0, 0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}, + {0, 0, 0}, + {0, 0, 0})); + + using DoutGridDesc_M_K = remove_cvref_t>; + using DinGridDesc_M = remove_cvref_t>; + + // FIXME + // for NDHWC, the dim C is the fastest dimension, and is not reduced. + // Hence, it is in M dimension for reduction kernel. + static constexpr index_t OutSrcInDstVectorDim = 0; // 0: M, 1: K + + using PassThrough = tensor_operation::element_wise::PassThrough; + using Div = tensor_operation::element_wise::UnaryDivide; + + using gridwise_reduce = GridwiseReduction_mk_to_m_threadwise; + + struct Argument : public BaseArgument + { + Argument(const DOutDataType* p_dout, + DInDataType* p_din, + std::vector dout_n_c_wos_lengths, + std::vector din_n_c_wos_length, + std::vector dout_n_c_wos_strides, + std::vector din_n_c_wos_strides, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + : p_dout_grid_{p_dout}, + p_din_grid_{p_din}, + dout_n_c_wos_lengths_{dout_n_c_wos_lengths}, + din_n_c_wos_length_{din_n_c_wos_length}, + dout_n_c_wos_strides_{dout_n_c_wos_strides}, + din_n_c_wos_strides_{din_n_c_wos_strides}, + num_reduce_{1}, + div_element_op_{window_lengths[0] * window_lengths[1] * window_lengths[2]} + { + std::vector Tildes(NDimSpatial); + for(int i = 0; i < NDimSpatial; ++i) + { + int GcdStrideDilation = math::gcd(window_strides[i], window_dilations[i]); + Tildes[i] = window_strides[i] / GcdStrideDilation; + num_reduce_ *= Tildes[i]; + } + + for(index_t i_ztilde = 0; i_ztilde < Tildes[0]; ++i_ztilde) + { + for(index_t i_ytilde = 0; i_ytilde < Tildes[1]; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < Tildes[2]; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = + math::integer_divide_ceil(window_lengths[0] - i_ztilde, Tildes[0]); + const auto YDotSlice = + math::integer_divide_ceil(window_lengths[1] - i_ytilde, Tildes[1]); + const auto XDotSlice = + math::integer_divide_ceil(window_lengths[2] - i_xtilde, Tildes[2]); + + if(ZDotSlice * YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto dout_din_grid_desc = + Make3DGridDescriptor_Out_M_K_In_M(dout_n_c_wos_lengths, + din_n_c_wos_length, + dout_n_c_wos_strides, + din_n_c_wos_strides, + window_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads, + {i_ztilde, i_ytilde, i_xtilde}); + + dout_grid_desc_m_k_container_.push_back(dout_din_grid_desc[I0]); + din_grid_desc_m_container_.push_back(dout_din_grid_desc[I1]); + } + } + } + } + + const DOutDataType* p_dout_grid_; + DInDataType* p_din_grid_; + std::vector dout_n_c_wos_lengths_; + 
std::vector din_n_c_wos_length_; + std::vector dout_n_c_wos_strides_; + std::vector din_n_c_wos_strides_; + + int num_reduce_; + std::vector dout_grid_desc_m_k_container_; + std::vector din_grid_desc_m_container_; + + Div div_element_op_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + for(index_t i = 0; i < arg.num_reduce_; i++) + { + const auto kernel = kernel_reduce_threadwise; + + ck::index_t M = arg.dout_grid_desc_m_k_container_[i].GetLength(I0); + const index_t grid_size = (M / M_BlockTileSize); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.dout_grid_desc_m_k_container_[i], + arg.din_grid_desc_m_container_[i], + PassThrough{}, + arg.div_element_op_, + float(1), + arg.p_dout_grid_, + nullptr, + float(0), + arg.p_din_grid_, + nullptr); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + constexpr index_t Rank = NDimSpatial + 2; + int doutFastestDim = -1; + int dinFastestDim = -1; + + for(int i = 0; i < Rank; ++i) + { + if(arg.dout_n_c_wos_strides_[i] == 1) + doutFastestDim = i; + if(arg.din_n_c_wos_strides_[i] == 1) + dinFastestDim = i; + } + + if(doutFastestDim == -1 || dinFastestDim == -1) + { + if constexpr(InSrcOutDstVectorSize != 1) + return false; + } + else + { + if(arg.dout_n_c_wos_lengths_[doutFastestDim] % InSrcOutDstVectorSize != 0) + return false; + if(arg.din_n_c_wos_length_[dinFastestDim] % InSrcOutDstVectorSize != 0) + return false; + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + std::unique_ptr + MakeArgumentPointer(const void* p_dout, + void* p_din, + std::vector dout_n_c_wos_lengths, + std::vector din_n_c_wos_length, + std::vector dout_n_c_wos_strides, + std::vector din_n_c_wos_strides, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations, + std::vector input_left_pads, + std::vector input_right_pads) override + { + constexpr index_t Rank = NDimSpatial + 2; + + if(dout_n_c_wos_strides.size() != Rank || din_n_c_wos_strides.size() != Rank || + dout_n_c_wos_lengths.size() != Rank || din_n_c_wos_length.size() != Rank) + throw std::runtime_error("dimension is incorrect"); + + if(window_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial || + window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial || + input_right_pads.size() != NDimSpatial) + throw std::runtime_error("dimension is incorrect"); + + return std::make_unique(static_cast(p_dout), + static_cast(p_din), + dout_n_c_wos_lengths, + din_n_c_wos_length, + dout_n_c_wos_strides, + din_n_c_wos_strides, + window_lengths, + window_strides, + window_dilations, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceAvgPool3dBwd<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << 
">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp new file mode 100755 index 000000000000..ab3f3856aaab --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,1071 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] +// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] +// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] + +// NOTE: TensorSpecialization::Packed specialized tensor is "packed" in a sense that each inner +// dimension in a dimension group (eg [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and +// ordered. Not in a sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted +// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into +// TensorSpecialization::Default with NumDimG/M/N/K = 1 +// +// Detail- Packed tensor satisfies +// stride_0 = 1 +// stride_i = stride_{i - 1} * extent_{i - 1} +// So tensor +// [G0, G1, G2, M, N] +// transposed into tensor +// [G0, G2, G1, M, N] +// with strides +// [G2 * G1 * M * N, G1 * M * N, M * N, N, 1] +// is again a packed tensor. MakeGridDescriptor() currently just merges dimensions and ignores some +// strides from input tensor extents so finer dimension information is lost. Merging dimensions is +// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1. 
+// +// Might need to expose dimension order to the interface to fully support +// TensorSpecialization::Packed in a traditional sense of "packed" tensor +template +struct DeviceBatchedContractionMultipleD_Wmma_CShuffle + : public DeviceBatchedContractionMultipleD +{ + using DeviceOp = DeviceBatchedContractionMultipleD_Wmma_CShuffle; + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + // K1 = Max Vector Access Pixels + static constexpr auto K1Number = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + + static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false; + static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false; + + static constexpr auto AEnableLds_auto = + (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1)) ? false : true; + static constexpr auto BEnableLds_auto = + (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1)) ? false : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = false; + static constexpr auto BEnableLds_manu = false; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] + static auto MakeAGridDescriptor(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK && + a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto a_ms_ks_lengths = to_tuple( + a_gs_ms_ks_lengths_vec, Number{}, Number{}); + const auto a_ms_ks_strides = to_tuple( + a_gs_ms_ks_strides_vec, Number{}, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); + + const auto a_grid_desc_m_k = [&]() { + if constexpr(ASpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(M, K), + make_tuple(a_ms_ks_strides[Number{}], + a_ms_ks_strides[Number{}])); + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else + { + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... 
, KRaw = K0 * K1 * K2 * ...] + const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + }(); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(AEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] + static auto MakeBGridDescriptor(const std::vector& b_gs_ns_ks_lengths_vec, + const std::vector& b_gs_ns_ks_strides_vec) + { + assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK && + b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto b_ns_ks_lengths = to_tuple( + b_gs_ns_ks_lengths_vec, Number{}, Number{}); + const auto b_ns_ks_strides = to_tuple( + b_gs_ns_ks_strides_vec, Number{}, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); + + const auto b_grid_desc_n_k = [&]() { + if constexpr(BSpec == TensorSpecialization::Packed) + { + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor( + make_tuple(N, K), + make_tuple(b_ns_ks_strides[Number{}], + b_ns_ks_strides[Number{}])); + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else + { + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] + const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] 
+ const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + }(); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_M_N(const std::vector& e_gs_ms_ns_lengths_vec, + const std::vector& e_gs_ms_ns_strides_vec) + { + assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto e_ms_ns_lengths = to_tuple( + e_gs_ms_ns_lengths_vec, Number{}, Number{}); + const auto e_ms_ns_strides = to_tuple( + e_gs_ms_ns_strides_vec, Number{}, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); + + if constexpr(DESpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(M, N), + make_tuple(e_ms_ns_strides[Number{}], + e_ms_ns_strides[Number{}])); + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + else + { + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_ms_ns = + make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); + + // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] + const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + } + + // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] 
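+// Unlike MakeEGridDescriptor_M_N above, the G dimensions are kept here and merged into a
+// single batch dimension G = G0 * G1 * ..., and no M/N padding is applied; the resulting
+// G_M_N descriptor is used by ComputePtrOffsetOfStridedBatch to compute per-batch pointer
+// offsets and to obtain the batch count in the invoker.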
+ static auto MakeEGridDescriptor_G_M_N(const std::vector& e_gs_ms_ns_lengths_vec, + const std::vector& e_gs_ms_ns_strides_vec) + { + assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto e_gs_ms_ns_lengths = + to_tuple(e_gs_ms_ns_lengths_vec, Number<0>{}, Number{}); + const auto e_gs_ms_ns_strides = + to_tuple(e_gs_ms_ns_strides_vec, Number<0>{}, Number{}); + + // dimension Ids for G0, G1, ... + constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{}; + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = + typename arithmetic_sequence_gen::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen::type{}; + + // lengths for G0, G1, ... + const auto gLengths = get_container_subset(e_gs_ms_ns_lengths, gDimIds); + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_gs_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_gs_ms_ns_lengths, nDimIds); + + if constexpr(DESpec == TensorSpecialization::Packed) + { + auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{}); + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto e_grid_desc_g_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(G, M, N), + make_tuple(e_gs_ms_ns_strides[Number{}], + e_gs_ms_ns_strides[Number{}], + e_gs_ms_ns_strides[Number{}])); + // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw); + return e_grid_desc_g_mraw_nraw; + } + else + { + // naive tensor E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_gs_ms_ns = + make_naive_tensor_descriptor(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + // transformed tensor E[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] + const auto e_grid_desc_g_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_gs_ms_ns, + make_tuple(make_merge_transform(gLengths), + make_merge_transform(mLengths), + make_merge_transform(nLengths)), + make_tuple(gDimIds, mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw); + return e_grid_desc_g_mraw_nraw; + } + } + + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_gs_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_gs_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths_vec[i], + ds_gs_ms_ns_strides_vec[i]); + }, + Number{}); + } + + static auto MakeDsGridDescriptor_G_M_N( + const std::array, NumDTensor>& ds_gs_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_gs_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_G_M_N(ds_gs_ms_ns_lengths_vec[i], + ds_gs_ms_ns_strides_vec[i]); + }, + Number{}); + } + + // Gridwise descriptor, mapping to whole given provblem. 
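+// The empty {} arguments below only serve to deduce the descriptor types via decltype at
+// compile time; the actual descriptors are built from the runtime lengths/strides inside
+// Argument.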
+ using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + + using DsGridDesc_G_M_N = remove_cvref_t; + using EGridDesc_G_M_N = decltype(MakeEGridDescriptor_G_M_N({}, {})); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t batch_stride_A, + index_t batch_stride_B, + DsGridDesc_G_M_N ds_grid_desc_g_m_n, + EGridDesc_G_M_N e_grid_desc_g_m_n) + : batch_stride_A_(batch_stride_A), + batch_stride_B_(batch_stride_B), + ds_grid_desc_g_m_n_(ds_grid_desc_g_m_n), + e_grid_desc_g_m_n_(e_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * batch_stride_A_; + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * batch_stride_B_; + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_offset[i] = static_cast(g_idx) * + ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0)); + }); + + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * + e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0)); + } + + private: + index_t batch_stride_A_; + index_t batch_stride_B_; + DsGridDesc_G_M_N ds_grid_desc_g_m_n_; + EGridDesc_G_M_N e_grid_desc_g_m_n_; + }; + + using AGridDesc = decltype(DeviceOp::MakeAGridDescriptor({}, {})); + using BGridDesc = decltype(DeviceOp::MakeBGridDescriptor({}, {})); + + // GridwiseOp + using GridwiseOp = GridwiseGemmMultipleD_Wmma< + // DataType Family + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + // InMemory Data Descriptor + AGridDesc, + BGridDesc, + DsGridDesc_M_N, + EGridDesc_M_N, + // ElementwiseOp Family + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + // Tiling Family + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + K1, + MRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + AEnableLds, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BEnableLds, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + NumPrefetch, + LoopSched, + PipelineVer>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& b_gs_ns_ks_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& 
e_gs_ms_ns_strides, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_{}, + b_grid_desc_{}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{}, + ds_grid_desc_g_m_n_{ + DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)}, + e_grid_desc_g_m_n_{ + DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock{}, + e_grid_desc_mblock_mperblock_nblock_nperblock{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_mz_stride_{}, + a_kz_stride_{}, + b_nz_stride_{}, + b_kz_stride_{}, + ds_nz_stride_{}, + e_nz_stride_{}, + a_batch_stride_{a_gs_ms_ks_strides[NumDimG - 1]}, + b_batch_stride_{b_gs_ns_ks_strides[NumDimG - 1]}, + compute_ptr_offset_of_batch_{ + a_batch_stride_, b_batch_stride_, ds_grid_desc_g_m_n_, e_grid_desc_g_m_n_} + { + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + }); + + a_grid_desc_ = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + b_grid_desc_ = DeviceOp::MakeBGridDescriptor(b_gs_ns_ks_lengths, b_gs_ns_ks_strides); + + ds_grid_desc_m_n_ = + DeviceOp::MakeDsGridDescriptor_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides); + + e_grid_desc_m_n_ = + DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01); + + ds_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); + + // for sanity check of vector memory access + a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1]; + a_kz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]; + b_nz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN - 1]; + b_kz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]; + + for(index_t i = 0; i < NumDTensor; ++i) + { + ds_nz_stride_[i] = ds_gs_ms_ns_strides[i][NumDimG + NumDimM + NumDimN - 1]; + } + + e_nz_stride_ = e_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]; + } + + // Pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseOp::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // Tensor Descriptors + AGridDesc a_grid_desc_; + BGridDesc b_grid_desc_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + DsGridDesc_G_M_N ds_grid_desc_g_m_n_; + EGridDesc_G_M_N e_grid_desc_g_m_n_; + + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock; + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock; + + // Block to Tile mapping + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_; + + // Idle + index_t M01_; + index_t N01_; + + // ElementwiseOp + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // Strides for the last M/N/K dimensions of A/B/Ds/E + // for 
sanity check of vector load/store + index_t a_mz_stride_; + index_t a_kz_stride_; + index_t b_nz_stride_; + index_t b_kz_stride_; + std::array ds_nz_stride_; + index_t e_mz_stride_; + index_t e_nz_stride_; + + index_t a_batch_stride_; + index_t b_batch_stride_; + + // Batch Offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // for checking vector load/store + // index_t MRaw_; + // index_t NRaw_; + // index_t KRaw_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0); + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G; + + const auto K = [&]() { + if constexpr(AEnableLds) + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2); + } + else + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) * + arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6); + } + }(); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_contraction_multiple_d_wmma_cshuffle< + GridwiseOp, + ADataType, + BDataType, + typename GridwiseOp::DsGridPointer, + EDataType, + DeviceOp::AGridDesc, + DeviceOp::BGridDesc, + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + typename GridwiseOp::DefaultBlock2CTileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + G, + arg.a_grid_desc_, + arg.b_grid_desc_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_ctile_map_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Arch check failure\n"); + return false; + } + } + else + { + return false; + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc_, + arg.b_grid_desc_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + printf("GridwiseOp: Validity check failure\n"); + return false; + } + + // check vector access + static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) && + (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), + "wrong!"); + + // vector memory access of A: could be on M or AK1 dimension + if constexpr(ABlockTransferSrcVectorDim == 1) + { + if(!(arg.a_mz_stride_ == 1 && + arg.a_grid_desc_.GetLength(I1) % 
ABlockTransferSrcScalarPerVector == 0)) + { + printf("DeviceOp: Vector Access A-m check failure\n"); + return false; + } + } + else + { + if(!(arg.a_kz_stride_ == 1)) + { + index_t LastK = + AEnableLds ? arg.a_grid_desc_.GetLength(I2) : arg.a_grid_desc_.GetLength(I6); + if(LastK % ABlockTransferSrcScalarPerVector == 0) + { + printf("DeviceOp: Vector Access A-k check failure\n"); + return false; + } + } + } + + // vector memory access of B: could be on N or BK1 dimension + if constexpr(BBlockTransferSrcVectorDim == 1) + { + if(!(arg.b_nz_stride_ == 1 && + arg.b_grid_desc_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) + { + printf("DeviceOp: Vector Access B-n check failure\n"); + return false; + } + } + else + { + if(!(arg.b_kz_stride_ == 1 && + arg.b_grid_desc_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) + { + printf("DeviceOp: Vector Access B-k check failure\n"); + return false; + } + } + + // vector memory access of Ds: always on NPerBlock dimension + bool valid_d_access = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + if(!(arg.ds_nz_stride_[i] == 1 && + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetLength(I3) % + CDEShuffleBlockTransferScalarPerVector_NPerBlock == + 0)) + { + printf("DeviceOp: Vector Access D-n check failure\n"); + valid_d_access = false; + } + }); + + if(valid_d_access == false) + { + return false; + } + + // vector memory access of E: always on NPerBlock dimension + if(!((arg.e_nz_stride_ == 1 && + arg.e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I3) % + CDEShuffleBlockTransferScalarPerVector_NPerBlock == + 0) || + CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1)) + { + printf("DeviceOp: Vector Access E-n check failure\n"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_gs_ms_ks_lengths, + b_gs_ns_ks_lengths, + ds_gs_ms_ns_lengths, + e_gs_ms_ns_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_strides, + ds_gs_ms_ns_strides, + e_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op}; + } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_gs_ms_ks_lengths, + b_gs_ns_ks_lengths, + ds_gs_ms_ns_lengths, + e_gs_ms_ns_lengths, + a_gs_ms_ks_strides, 
+ b_gs_ns_ks_strides, + ds_gs_ms_ns_strides, + e_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceBatchedContractionMultipleD_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << K1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp new file mode 100755 index 000000000000..fc1a2b995ae9 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,1048 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
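+// Batched tensor contraction with multiple auxiliary D tensors on the XDL (MFMA) path,
+// guarded for gfx9. The kernel below expects a grid of blocks-per-batch * batch_count
+// workgroups, derives the batch index from the block id, offsets the A/B/Ds/E pointers via
+// ComputePtrOffsetOfBatch, and then runs GridwiseGemm on the per-batch sub-problem.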
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_contraction_multiple_d_xdl_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + FloatDsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; + ignore = compute_ptr_offset_of_batch; +#endif +} + +} // 
namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] +// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] +// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] + +// NOTE: TensorSpecialization::Packed specialized tensor is "packed" in a sense that each inner +// dimension in a dimension group (eg [G0, G1] in Gs, [M0, M1, M2] in Ms, etc.) are contiguous and +// ordered. Not in a sense that the tensor [G0, G1, ..., M0, M1, ..., N0, N1...] can be permuted +// while still being a contiguous, unpadded tensor. In other words, it merely degenerates into +// TensorSpecialization::Default with NumDimG/M/N/K = 1 +// +// Detail- Packed tensor satisfies +// stride_0 = 1 +// stride_i = stride_{i - 1} * extent_{i - 1} +// So tensor +// [G0, G1, G2, M, N] +// transposed into tensor +// [G0, G2, G1, M, N] +// with strides +// [G2 * G1 * M * N, G1 * M * N, M * N, N, 1] +// is again a packed tensor. MakeGridDescriptor() currently just merges dimensions and ignores some +// strides from input tensor extents so finer dimension information is lost. Merging dimensions is +// essentially a degenerated case of TensorSpecialization::Default with NumDimG/M/N/K = 1. +// +// Might need to expose dimension order to the interface to fully support +// TensorSpecialization::Packed in a traditional sense of "packed" tensor +template +struct DeviceBatchedContractionMultipleD_Xdl_CShuffle + : public DeviceBatchedContractionMultipleD +{ + using DeviceOp = DeviceBatchedContractionMultipleD_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] + static auto MakeAGridDescriptor_M_K(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK && + a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto a_ms_ks_lengths = to_tuple( + a_gs_ms_ks_lengths_vec, Number{}, Number{}); + const auto a_ms_ks_strides = to_tuple( + a_gs_ms_ks_strides_vec, Number{}, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); + + // lengths for K0, K1, ... 
+ const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); + + if constexpr(ASpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(M, K), + make_tuple(a_ms_ks_strides[Number{}], + a_ms_ks_strides[Number{}])); + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else + { + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] + const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + } + + // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] + static auto MakeBGridDescriptor_N_K(const std::vector& b_gs_ns_ks_lengths_vec, + const std::vector& b_gs_ns_ks_strides_vec) + { + assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK && + b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto b_ns_ks_lengths = to_tuple( + b_gs_ns_ks_lengths_vec, Number{}, Number{}); + const auto b_ns_ks_strides = to_tuple( + b_gs_ns_ks_strides_vec, Number{}, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); + + if constexpr(BSpec == TensorSpecialization::Packed) + { + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor( + make_tuple(N, K), + make_tuple(b_ns_ks_strides[Number{}], + b_ns_ks_strides[Number{}])); + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else + { + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] + const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + } + + // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] 
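+// The M and N dimension groups are merged into MRaw = M0 * M1 * ... and NRaw = N0 * N1 * ...,
+// then padded by matrix_padder (according to the chosen gemm specialization) up to multiples
+// of MPerBlock / NPerBlock, mirroring the A and B descriptor construction above.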
+ static auto MakeEGridDescriptor_M_N(const std::vector& e_gs_ms_ns_lengths_vec, + const std::vector& e_gs_ms_ns_strides_vec) + { + assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto e_ms_ns_lengths = to_tuple( + e_gs_ms_ns_lengths_vec, Number{}, Number{}); + const auto e_ms_ns_strides = to_tuple( + e_gs_ms_ns_strides_vec, Number{}, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); + + if constexpr(DESpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(M, N), + make_tuple(e_ms_ns_strides[Number{}], + e_ms_ns_strides[Number{}])); + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + else + { + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_ms_ns = + make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); + + // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] + const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + } + + // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_G_M_N(const std::vector& e_gs_ms_ns_lengths_vec, + const std::vector& e_gs_ms_ns_strides_vec) + { + assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto e_gs_ms_ns_lengths = + to_tuple(e_gs_ms_ns_lengths_vec, Number<0>{}, Number{}); + const auto e_gs_ms_ns_strides = + to_tuple(e_gs_ms_ns_strides_vec, Number<0>{}, Number{}); + + // dimension Ids for G0, G1, ... + constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{}; + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = + typename arithmetic_sequence_gen::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen::type{}; + + // lengths for G0, G1, ... + const auto gLengths = get_container_subset(e_gs_ms_ns_lengths, gDimIds); + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_gs_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... 
+ const auto nLengths = get_container_subset(e_gs_ms_ns_lengths, nDimIds); + + if constexpr(DESpec == TensorSpecialization::Packed) + { + auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{}); + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto e_grid_desc_g_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(G, M, N), + make_tuple(e_gs_ms_ns_strides[Number{}], + e_gs_ms_ns_strides[Number{}], + e_gs_ms_ns_strides[Number{}])); + // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw); + return e_grid_desc_g_mraw_nraw; + } + else + { + // naive tensor E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_gs_ms_ns = + make_naive_tensor_descriptor(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + // transformed tensor E[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] + const auto e_grid_desc_g_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_gs_ms_ns, + make_tuple(make_merge_transform(gLengths), + make_merge_transform(mLengths), + make_merge_transform(nLengths)), + make_tuple(gDimIds, mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw); + return e_grid_desc_g_mraw_nraw; + } + } + + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_gs_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_gs_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths_vec[i], + ds_gs_ms_ns_strides_vec[i]); + }, + Number{}); + } + + static auto MakeDsGridDescriptor_G_M_N( + const std::array, NumDTensor>& ds_gs_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_gs_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_G_M_N(ds_gs_ms_ns_lengths_vec[i], + ds_gs_ms_ns_strides_vec[i]); + }, + Number{}); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {})); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {})); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + + using DsGridDesc_G_M_N = remove_cvref_t; + using EGridDesc_G_M_N = decltype(MakeEGridDescriptor_G_M_N({}, {})); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t batch_stride_A, + index_t batch_stride_B, + DsGridDesc_G_M_N ds_grid_desc_g_m_n, + EGridDesc_G_M_N e_grid_desc_g_m_n) + : batch_stride_A_(batch_stride_A), + batch_stride_B_(batch_stride_B), + ds_grid_desc_g_m_n_(ds_grid_desc_g_m_n), + e_grid_desc_g_m_n_(e_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * batch_stride_A_; + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * batch_stride_B_; + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_offset[i] = static_cast(g_idx) * + ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0)); + }); + + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * + e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0)); + } + + private: + index_t batch_stride_A_; + 
index_t batch_stride_B_; + DsGridDesc_G_M_N ds_grid_desc_g_m_n_; + EGridDesc_G_M_N e_grid_desc_g_m_n_; + }; + + using ComputeDataType = ADataType; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ADataType, + BDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_m_k_{ + DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ns_lengths, a_gs_ms_ks_strides)}, + b_grid_desc_n_k_{ + DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, + ds_grid_desc_g_m_n_{ + DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)}, + e_grid_desc_g_m_n_{ + DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_mz_stride_{}, + a_kz_stride_{}, + b_nz_stride_{}, + b_kz_stride_{}, + 
ds_nz_stride_{}, + e_nz_stride_{}, + a_batch_stride_{a_gs_ms_ks_strides[NumDimG - 1]}, + b_batch_stride_{b_gs_ns_ks_strides[NumDimG - 1]}, + compute_ptr_offset_of_batch_{ + a_batch_stride_, b_batch_stride_, ds_grid_desc_g_m_n_, e_grid_desc_g_m_n_} + { + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0, ""); + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths[i], + ds_gs_ms_ns_strides[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + + // for sanity check of vector memory access + a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1]; + a_kz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]; + b_nz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN - 1]; + b_kz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]; + + for(index_t i = 0; i < NumDTensor; ++i) + { + ds_nz_stride_[i] = ds_gs_ms_ns_strides[i][NumDimG + NumDimM + NumDimN - 1]; + } + + e_nz_stride_ = e_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]; + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + DsGridDesc_G_M_N ds_grid_desc_g_m_n_; + EGridDesc_G_M_N e_grid_desc_g_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // Strides for the last M/N/K dimensions of A/B/Ds/E + // for sanity check of vector load/store + index_t a_mz_stride_; + index_t a_kz_stride_; + index_t b_nz_stride_; + index_t b_kz_stride_; + std::array ds_nz_stride_; + index_t e_mz_stride_; + index_t e_nz_stride_; + + index_t a_batch_stride_; + index_t b_batch_stride_; + + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + 
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"); + } + + const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0); + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G; + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + DeviceOp::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + G, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_etile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + return false; + } + + // check vector access + static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) && + (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), + "wrong!"); + + // vector memory access of A: could be on M or AK1 dimension + if constexpr(ABlockTransferSrcVectorDim == 1) + { + if(!(arg.a_mz_stride_ == 1 && + arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + if(!(arg.a_kz_stride_ == 1 && + arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + + // vector memory access of B: could be on N or BK1 dimension + if constexpr(BBlockTransferSrcVectorDim == 1) + { + if(!(arg.b_nz_stride_ == 1 && + arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + if(!(arg.b_kz_stride_ == 1 && + arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + + // vector memory access of Ds: always on NPerBlock dimension + bool valid_d_access = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + if(!(arg.ds_nz_stride_[i] == 1 && + 
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0)) + { + valid_d_access = false; + } + }); + + if(valid_d_access == false) + { + return false; + } + + // vector memory access of E: always on NPerBlock dimension + if(!((arg.e_nz_stride_ == 1 && + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0) || + CDEBlockTransferScalarPerVector_NPerBlock == 1)) + { + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_gs_ms_ns_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + ds_gs_ms_ns_lengths, + ds_gs_ms_ns_strides, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_gs_ms_ns_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + ds_gs_ms_ns_lengths, + ds_gs_ms_ns_strides, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedContractionMultipleD_Xdl_CShuffle" + << "<" + << NumDimG << ", " + << NumDimM << ", " + << NumDimN << ", " + << NumDimK << ", " + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << ABlockTransferSrcVectorDim << ", " + << BBlockTransferSrcVectorDim + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp new file mode 100755 index 000000000000..0cd1d84a4327 --- /dev/null +++ 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -0,0 +1,696 @@ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_e_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of + * pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemmCPermute and GroupedGemm (and the corresponding GEMM fusion).
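+ *
+ * (Editorial sketch, not part of the original kernel; the layout and names below are purely
+ * illustrative.) For evenly strided batches the per-batch offset is just the batch index times
+ * the batch stride, widened to \p long_index_t before any pointer arithmetic so large tensors
+ * do not overflow 32-bit indexing:
+ * \code
+ *   // hypothetical layout: A is [batch_count, M, K] row-major, so batch_stride_A = M * K
+ *   const long_index_t a_offset = static_cast<long_index_t>(g_idx) * batch_stride_A;
+ *   const ADataType* p_a_batch  = p_a_grid + a_offset; // base pointer of this workgroup's batch
+ * \endcode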
+ * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_e_permute_xdl(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + EDataType* __restrict__ p_e_grid, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + ck::Tuple<>{}, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ck::Tuple<>{}, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_etile_map; +#endif +} + +template +struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute +{ + using DeviceOp = DeviceBatchedGemmEPermuteXdl; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + static auto + 
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N) + { + const auto e_grid_desc_mraw_nraw = + make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), make_tuple(stride_M, stride_N)); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0, + index_t G1, + index_t MRaw, + index_t NRaw, + index_t stride_G0, + index_t stride_G1, + index_t stride_M, + index_t stride_N) + { + const auto e_grid_desc_g0_g1_mraw_nraw = [&]() { + return make_naive_tensor_descriptor( + make_tuple(G0, G1, MRaw, NRaw), + make_tuple(stride_G0, stride_G1, stride_M, stride_N)); + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor( + e_grid_desc_g0_g1_mraw_nraw, + make_tuple(make_pass_through_transform(G0), + make_pass_through_transform(G1), + make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + e_grid_desc_g0_g1_mraw_nraw, + make_tuple(make_pass_through_transform(G0), + make_pass_through_transform(G1), + make_right_pad_transform(MRaw, MPad), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + e_grid_desc_g0_g1_mraw_nraw, + make_tuple(make_pass_through_transform(G0), + make_pass_through_transform(G1), + make_pass_through_transform(MRaw), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + } + else + { + // not pad M or N + return e_grid_desc_g0_g1_mraw_nraw; + } + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1, 1)); + using EGridDesc_G0_G1_M_N = decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1)); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t Batchstride_A, + index_t Batchstride_B, + EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n) + : Batchstride_A_(Batchstride_A), + Batchstride_B_(Batchstride_B), + e_grid_desc_g0_g1_m_n_(e_grid_desc_g0_g1_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(Batchstride_A_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(Batchstride_B_); + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + const index_t G1 = e_grid_desc_g0_g1_m_n_.GetLength(I1); + index_t b0 = g_idx / G1; + index_t 
b1 = g_idx - b0 * G1; // g_idx % G1 + return e_grid_desc_g0_g1_m_n_.CalculateOffset(make_multi_index(b0, b1, 0, 0)); + } + + private: + index_t Batchstride_A_; + index_t Batchstride_B_; + EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_; + }; + + using ComputeDataType = ADataType; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ADataType, + BDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, // DsDataType, + EDataType, // EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + AGridDesc_M_K, + BGridDesc_N_K, + Tuple<>, + EGridDesc_M_N, + NumPrefetch, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + EGridDesc_M_N{})); + using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + EDataType* p_e_grid, + index_t M, + index_t N, + index_t K, + index_t stride_A, + index_t stride_B, + index_t batch_stride_A, + index_t batch_stride_B, + BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, + index_t BatchCount, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_e_grid_{p_e_grid}, + BatchCount_(BatchCount), + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(M, K, stride_A)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(K, N, stride_B)}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(batched_gemm_e_permute_desc.M_, + batched_gemm_e_permute_desc.N_, + batched_gemm_e_permute_desc.stride_M_, + batched_gemm_e_permute_desc.stride_N_)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + e_grid_desc_mblock_mperblock_nblock_nperblock{}, + e_grid_desc_g0_g1_m_n_{ + DeviceOp::MakeEGridDescriptor_G0_G1_M_N(batched_gemm_e_permute_desc.G0_, + batched_gemm_e_permute_desc.G1_, + batched_gemm_e_permute_desc.M_, + batched_gemm_e_permute_desc.N_, + batched_gemm_e_permute_desc.stride_G0_, + batched_gemm_e_permute_desc.stride_G1_, + batched_gemm_e_permute_desc.stride_M_, + batched_gemm_e_permute_desc.stride_N_)}, + compute_ptr_offset_of_batch_{batch_stride_A, batch_stride_B, e_grid_desc_g0_g1_m_n_}, + 
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ck::Tuple<>{}, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + } + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + std::cout << "C[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + + // batch count + index_t BatchCount_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock; + EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_; + + // for calculating Batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + ck::Tuple<>{}, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseBatchedGemmCPermute_km_kn_m0m1n0n1_xdlops_v2r3 has invalid " + "setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.BatchCount_; + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = kernel_batched_gemm_e_permute_xdl< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + EDataType, + remove_reference_t, + remove_reference_t, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + remove_reference_t, + has_main_k_block_loop_>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_e_grid_, + arg.BatchCount_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_etile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + ck::Tuple<>{}, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + EDataType* p_e, + index_t M, + index_t N, + index_t K, + index_t stride_A, + index_t stride_B, + index_t batch_stride_A, + index_t batch_stride_B, + BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, + index_t BatchCount, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_e, + M, + N, + K, + stride_A, + stride_B, + batch_stride_A, + batch_stride_B, + batched_gemm_e_permute_desc, + BatchCount, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t stride_A, + index_t stride_B, + index_t batch_stride_A, + index_t batch_stride_B, + BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, + index_t BatchCount, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_e), + M, + N, + K, + stride_A, + stride_B, + batch_stride_A, + batch_stride_B, + batched_gemm_e_permute_desc, + BatchCount, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + 
return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmEPermuteXdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp new file mode 100755 index 000000000000..985752796bfa --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp @@ -0,0 +1,747 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_gemm_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatAB* __restrict__ p_b1_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const B1ElementwiseOperation b1_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const index_t batch_count, + const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + 
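+    // Editorial note (not part of the original patch): the batch offsets above are wrapped in
+    // __builtin_amdgcn_readfirstlane to mark them as wave-uniform, so the per-batch base-pointer
+    // arithmetic can be kept in scalar registers rather than being replicated per lane.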
GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = acc_element_op; + ignore = b1_element_op; + ignore = c_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = b1_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_ctile_map; + ignore = batch_count; + ignore = compute_base_ptr_of_batch; +#endif // end of if (defined(__gfx9__)) +} + +// Computes C = A * B0 * B1 +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm +{ + using DeviceOp = DeviceBatchedGemmGemm_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto matrix_padder = + GemmGemmPadder{ + MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // Args: Gemm1KRaw, Gemm1NRaw, StrideB1 + static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b1_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + 
const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw); + + const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); + } + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideB1_(BatchStrideB1), + BatchStrideC_(BatchStrideC) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB1_); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideB1_; + index_t BatchStrideC_; + }; + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseBatchedGemmGemm_Xdl_CShuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + B1GridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1, + BK1, + B1K1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + true, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + true, + BBlockLdsExtraN, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + 
B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + false, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, // = ORaw + index_t Batch, + index_t StrideA, + index_t StrideB, + index_t StrideB1, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + b1_grid_desc_bk0_n_bk1_{ + DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, Gemm1NRaw, StrideC)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + batch_count_(Batch), + compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC}, + raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + b1_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + void Print() const + { + std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl; + std::cout << "B0[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl; + std::cout << "B1[BK0, N, BK1]: " << b1_grid_desc_bk0_n_bk1_ << std::endl; + std::cout << "C[M, N]: " << c_grid_desc_m_n_ << std::endl; + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + index_t batch_count_; + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; + + // For robust IsSupportedArgument() check + std::vector raw_lengths_m_n_k_o_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + 
if(!DeviceOp::IsSupportedArgument(arg)) + { + throw std::runtime_error("wrong! unsupported argument"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_; + + // Gemm0_K + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = kernel_gemm_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::B1GridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + ComputeBasePtrOfStridedBatch, + has_main_k_block_loop_>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.b1_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_, + arg.batch_count_, + arg.compute_base_ptr_of_batch_); + }; + + // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need + // to concern Gemm0's loop + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + const auto MRaw = arg.raw_lengths_m_n_k_o_[0]; + const auto NRaw = arg.raw_lengths_m_n_k_o_[1]; + const auto KRaw = arg.raw_lengths_m_n_k_o_[2]; + const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = + is_same_v ? KRaw : MRaw; + const auto b_extent_lowest = + is_same_v ? NRaw : KRaw; + const auto b1_extent_lowest = + is_same_v ? Gemm1NRaw : NRaw; + const auto c_extent_lowest = + is_same_v ? 
Gemm1NRaw : MRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + const B1DataType* p_b1, + CDataType* p_c, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, + index_t Batch, + index_t StrideA, + index_t StrideB, + index_t StrideB1, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, p_b, p_b1, p_c, MRaw, + NRaw, KRaw, Gemm1NRaw, Batch, StrideA, + StrideB, StrideB1, StrideC, BatchStrideA, BatchStrideB, + BatchStrideB1, BatchStrideC, a_element_op, b_element_op, acc_element_op, + b1_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_b1, + void* p_c, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, + index_t Batch, + index_t StrideA, + index_t StrideB, + index_t StrideB1, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_b1), + static_cast(p_c), + MRaw, + NRaw, + KRaw, + Gemm1NRaw, + Batch, + StrideA, + StrideB, + StrideB1, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideB1, + BatchStrideC, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmGemm_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << Gemm1NPerBlock << ", " + << Gemm1KPerBlock << ", " + << B1K1 << ", " + << getGemmSpecializationString(GemmSpec) << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp new file mode 100755 index 000000000000..12085edaae5d --- /dev/null +++ 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp @@ -0,0 +1,724 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). 
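+ *
+ * (Editorial sketch, not part of the original kernel: a minimal view of the workgroup-to-batch
+ * mapping used by the kernel below; the identifiers are the kernel's own.) The grid is launched
+ * with batch_count copies of the per-batch tile grid, so a workgroup first recovers its batch
+ * index and only then consults Block2ETileMap for the C-tile within that batch:
+ * \code
+ *   const index_t num_blocks_per_batch = get_grid_size() / batch_count;            // tiles per batch
+ *   const index_t g_idx                = get_block_1d_id() / num_blocks_per_batch; // batch index
+ * \endcode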
+ * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_xdl(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2ETileMap block_2_etile_map) +{ + +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_etile_map; +#endif +} + +template +struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD +{ + using DeviceOp = DeviceBatchedGemmMultiD_Xdl; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return 
make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + std::array BatchStrideDs, + index_t BatchStrideE) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset; + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_offset[i] = g_idx * static_cast(BatchStrideDs_[i]); + }); + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + std::array BatchStrideDs_; + index_t BatchStrideE_; + }; + + using ComputeDataType = ADataType; + + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + 
BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Batch, + index_t StrideA, + index_t StrideB, + const std::array& StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + Batch_(Batch), + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideDs, BatchStrideE}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + } + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* 
p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // Batch + index_t Batch_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // for calculating batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceBatchedGemmMultiD_Xdl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.Batch_; + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = + kernel_batched_gemm_xdl; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.Batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_etile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + const std::array& StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + 
p_ds, + p_e, + M, + N, + K, + Batch, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + BatchStrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + const std::array& StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + Batch, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + BatchStrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmMultiD_Xdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp new file mode 100755 index 000000000000..1b487502f4f9 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp @@ -0,0 +1,792 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/* + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). 
\link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + */ + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dl_multiple_d( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, + const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ + defined(__gfx12__)) + + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = DsGridDesc_M0_M10_M11_N0_N10_N11::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseGemm::Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m0_m1_k1, + b_grid_desc_k0_n0_n1_k1, + ds_grid_desc_m0_m10_m11_n0_n10_n11, + e_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_ctile_map, + integral_constant{}, + integral_constant{}); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_k0_m0_m1_k1; + ignore = b_grid_desc_k0_n0_n1_k1; + ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = e_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; + +#endif +} + +template && + is_same_v, + bool> = false> +struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD + +{ + using DeviceOp = DeviceBatchedGemmMultipleD_Dl; + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static 
constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + template + static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(MRaws[i], 
NRaws[i], DsStride[i]); + }, + Number{}); + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {})); + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + std::array BatchStrideDs, + index_t BatchStrideE) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset; + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_offset[i] = g_idx * static_cast(BatchStrideDs_[i]); + }); + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + std::array BatchStrideDs_; + index_t BatchStrideE_; + }; + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDlMultipleD_km_kn_mn; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using DsGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{})); + using EGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{})); + using DefaultBlock2CTileMap = + decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + K_(K), + Batch_(Batch), + a_grid_desc_k0_m0_m1_k1_{}, + b_grid_desc_k0_n0_n1_k1_{}, + e_grid_desc_m0_m10_m11_n0_n10_n11_{}, + compute_ptr_offset_of_batch_{BatchStrideA, BatchStrideB, BatchStrideDs, BatchStrideE}, + block_2_ctile_map_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + a_grid_desc_k0_m_k1_ = + DeviceBatchedGemmMultipleD_Dl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = + DeviceBatchedGemmMultipleD_Dl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(M, N, StrideDs[i]); + }); 
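+            // Worked example of the descriptor math used above (illustrative numbers only):
+            // with K1 = 8 and K = 512, MakeAGridDescriptor_K0_M_K1 splits K into
+            // K0 = K / K1 = 64; IsSupportedArgument() later rejects any K that is not a
+            // multiple of K1. With GemmSpec == MNPadding, MPerBlock = 128 and M = 1000, the
+            // right pad is PadM = (128 - 1000 % 128) % 128 = 24, so M is padded to 1024
+            // (8 full tiles); an already aligned extent (e.g. M = 1024) gets PadM = 0.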
+ e_grid_desc_m_n_ = + DeviceBatchedGemmMultipleD_Dl::MakeEGridDescriptor_M_N(M, N, StrideE); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, e_grid_desc_m_n_)) + { + a_grid_desc_k0_m0_m1_k1_ = + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_); + b_grid_desc_k0_n0_n1_k1_ = + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_); + + ds_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n_); + + e_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(e_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + index_t K_; + + // Batch + index_t Batch_; + + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_; + BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_; + DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_; + EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_; + + // for calculating batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + DefaultBlock2CTileMap block_2_ctile_map_; + + // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being. + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceBatchedGemmMultipleD_Dl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{" + << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{" + << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", " + << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemmDlMultipleD_km_kn_mn has invalid setting"); + } + + const index_t grid_size = + GridwiseGemm::CalculateGridSize(arg.e_grid_desc_m_n_.GetLength(I0), + arg.e_grid_desc_m_n_.GetLength(I1)) * + arg.Batch_; + + auto launch_kernel = [&](auto has_main_k_block_loop, + auto has_double_tail_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr bool has_double_loop = has_double_tail_k_block_loop.value; + + const auto kernel = + kernel_gemm_dl_multiple_d; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.Batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.ds_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.e_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_ctile_map_); + }; + + const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || + ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + bool pass = true; + pass = pass && arg.K_ % K1 == 0; + + pass = pass && GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.e_grid_desc_m_n_); + + return pass; + } + else + { + return false; + } + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + Batch, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + BatchStrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t 
Batch, + index_t StrideA, + index_t StrideB, + const std::array& StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + Batch, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + BatchStrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmMultipleD_Dl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << M1PerThread << ", " + << N1PerThread << ", " + << KPerThread + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp new file mode 100755 index 000000000000..d38698af4b30 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,996 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
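+/*
+ * The device ops in these headers (DeviceBatchedGemmMultiD_Xdl, DeviceBatchedGemmMultipleD_Dl,
+ * and the fused GEMM+GEMM op defined below) share the same host-side calling pattern. A minimal
+ * sketch, assuming `DeviceOp` is a fully instantiated DeviceBatchedGemmMultiD_Xdl and all
+ * pointers, sizes, strides and element-wise ops are placeholders (the exact argument list
+ * differs per op):
+ *
+ *   auto argument = DeviceOp::MakeArgument(p_a, p_b, p_ds, p_e,
+ *                                          M, N, K, Batch,
+ *                                          StrideA, StrideB, StrideDs, StrideE,
+ *                                          BatchStrideA, BatchStrideB,
+ *                                          BatchStrideDs, BatchStrideE,
+ *                                          a_element_op, b_element_op, cde_element_op);
+ *   if(!DeviceOp::IsSupportedArgument(argument))
+ *       return;                                   // unsupported size/layout combination
+ *   auto invoker = DeviceOp::MakeInvoker();
+ *   float time   = invoker.Run(argument, StreamConfig{}); // launches the kernel and returns
+ *                                                         // the time measured by launch_and_time_kernel
+ */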
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_gemm_xdl_cshuffle_v1( + const A0B0B1DataType* __restrict__ p_a0_grid, + const A0B0B1DataType* __restrict__ p_b0_grid, + D0sPointer p_d0s_grid, + const A0B0B1DataType* __restrict__ p_b1_grid, + D1sPointer p_d1s_grid, + E1DataType* __restrict__ p_e1_grid, + const A0ElementwiseOperation a0_element_op, + const B0ElementwiseOperation b0_element_op, + const CDE0ElementwiseOperation cde0_element_op, + const B1ElementwiseOperation b1_element_op, + const CDE1ElementwiseOperation cde1_element_op, + const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1, + const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1, + const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, + const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + d1s_grid_desc_mblock_mperblock_nblock_nperblock, + const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e1_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2E1TileMap block_2_e1tile_map, + const index_t batch_count, + const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) { + const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In))); + p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset; + }); + + static_for<0, p_d1s_grid.Size(), 1>{}([&](auto In) { + const long_index_t d1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetD1BasePtr(g_idx, In))); + p_d1s_grid(In) = p_d1s_grid(In) + d1_batch_offset; + }); + + GridwiseGemm::template Run(p_a0_grid + a_batch_offset, + p_b0_grid + b_batch_offset, + p_d0s_grid, + p_b1_grid + b1_batch_offset, + p_d1s_grid, + p_e1_grid + c_batch_offset, + p_shared, + 
a0_element_op, + b0_element_op, + cde0_element_op, + b1_element_op, + cde1_element_op, + a0_grid_desc_ak0_m_ak1, + b0_grid_desc_bk0_n_bk1, + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + b1_grid_desc_bk0_n_bk1, + d1s_grid_desc_mblock_mperblock_nblock_nperblock, + e1_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_e1tile_map); +#else + ignore = p_a0_grid; + ignore = p_b0_grid; + ignore = p_d0s_grid; + ignore = p_b1_grid; + ignore = p_d1s_grid; + ignore = p_e1_grid; + ignore = a0_element_op; + ignore = b0_element_op; + ignore = cde0_element_op; + ignore = b1_element_op; + ignore = cde1_element_op; + ignore = a0_grid_desc_ak0_m_ak1; + ignore = b0_grid_desc_bk0_n_bk1; + ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5; + ignore = b1_grid_desc_bk0_n_bk1; + ignore = d1s_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e1_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_e1tile_map; + ignore = batch_count; + ignore = compute_base_ptr_of_batch; +#endif +} + +// Computes C = A * B0 * B1 +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle + : public DeviceBatchedGemmMultipleDGemmMultipleD +{ + using DeviceOp = DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle; + + static constexpr index_t NumD0Tensor = D0sDataType::Size(); + static constexpr index_t NumD1Tensor = D1sDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + static constexpr auto I8 = Number<8>{}; + static constexpr auto I9 = Number<9>{}; + + static constexpr auto gemm0_padder = + GemmPadder_v2{ + Gemm0MPerBlock, Gemm0NPerBlock, Gemm0KPerBlock}; + + static constexpr auto gemm1_padder = + GemmPadder_v2{ + Gemm0MPerBlock, Gemm1NPerBlock, Gemm1KPerBlock}; + + // for Gemm0 + static auto MakeA0GridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA0) + { + const auto a0_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA0, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA0)); + } + }(); + + return gemm0_padder.PadADescriptor_M_K(a0_grid_desc_mraw_kraw); + } + + // for Gemm0 + static auto MakeB0GridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b0_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return gemm0_padder.PadBDescriptor_N_K(b0_grid_desc_nraw_kraw); + } + + // for Gemm0 + template + static auto MakeD0GridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideD0) + { + const auto d0_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideD0, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideD0)); + } + }(); + + return gemm0_padder.PadCDescriptor_M_N(d0_grid_desc_mraw_nraw); + } + + // for Gemm1 + static auto 
MakeB1GridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b1_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return gemm1_padder.PadBDescriptor_N_K(b1_grid_desc_nraw_kraw); + } + + // for Gemm1 + template + static auto MakeE1GridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE1) + { + const auto e1_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE1, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE1)); + } + }(); + + return gemm1_padder.PadCDescriptor_M_N(e1_grid_desc_mraw_nraw); + } + + static auto MakeD0sGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeD0GridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + static auto MakeD1sGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeE1GridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA0, + index_t BatchStrideB0, + std::array BatchStrideD0s, + index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1) + : BatchStrideA0_(BatchStrideA0), + BatchStrideB0_(BatchStrideB0), + BatchStrideD0s_(BatchStrideD0s), + BatchStrideB1_(BatchStrideB1), + BatchStrideD1s_(BatchStrideD1s), + BatchStrideE1_(BatchStrideE1) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA0_); + } + + __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB0_); + } + + template + __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx, + Number d1_idx) const + { + return g_idx * static_cast(BatchStrideD0s_[d1_idx]); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB1_); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE1_); + } + + template + __host__ __device__ constexpr auto GetD1BasePtr(index_t g_idx, Number d1_idx) const + { + return g_idx * static_cast(BatchStrideD1s_[d1_idx]); + } + + private: + index_t BatchStrideA0_; + index_t BatchStrideB0_; + std::array BatchStrideD0s_; + index_t BatchStrideB1_; + std::array BatchStrideD1s_; + index_t BatchStrideE1_; + }; + + using A0GridDesc_M_K = decltype(MakeA0GridDescriptor_M_K(1, 1, 1)); + using B0GridDesc_N_K = decltype(MakeB0GridDescriptor_N_K(1, 1, 1)); + using D0sGridDesc_M_N = remove_cvref_t; + using B1GridDesc_N_K = decltype(MakeB1GridDescriptor_N_K(1, 1, 1)); + using D1sGridDesc_M_N = remove_cvref_t; + using E1GridDesc_M_N = decltype(MakeE1GridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle< + A0DataType, 
// TODO: distinguish A/B datatype + Acc0DataType, + D0sDataType, + Acc1DataType, + C1ShuffleDataType, + D1sDataType, + E1DataType, + A0ElementwiseOperation, + B0ElementwiseOperation, + CDE0ElementwiseOperation, + B1ElementwiseOperation, + CDE1ElementwiseOperation, + InMemoryDataOperationEnum::Set, + A0GridDesc_M_K, + B0GridDesc_N_K, + D0sGridDesc_M_N, + B1GridDesc_N_K, + D1sGridDesc_M_N, + E1GridDesc_M_N, + NumGemm0KPrefetchStage, + BlockSize, + Gemm0MPerBlock, + Gemm0NPerBlock, + Gemm0KPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + A0K1, + B0K1, + B1K1, + Gemm0MPerXdl, + Gemm0NPerXdl, + Gemm0MXdlPerWave, + Gemm0NXdlPerWave, + Gemm1NXdlPerWave, + A0BlockTransferThreadClusterLengths_AK0_M_AK1, + A0BlockTransferThreadClusterArrangeOrder, + A0BlockTransferSrcAccessOrder, + A0BlockTransferSrcVectorDim, + A0BlockTransferSrcScalarPerVector, + A0BlockTransferDstScalarPerVector_AK1, + true, + A0BlockLdsExtraM, + B0BlockTransferThreadClusterLengths_BK0_N_BK1, + B0BlockTransferThreadClusterArrangeOrder, + B0BlockTransferSrcAccessOrder, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_BK1, + true, + B0BlockLdsExtraN, + CDE0BlockTransferSrcVectorDim, + CDE0BlockTransferSrcScalaerPerVector, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + false, + B1BlockLdsExtraN, + C1ShuffleMXdlPerWavePerShuffle, + C1ShuffleGemm0NXdlPerWavePerShuffle, + CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + using A0GridDesc_AK0_M_AK1 = + remove_cvref_t; + using B0GridDesc_BK0_N_BK1 = + remove_cvref_t; + using B1GridDesc_BK0_N_BK1 = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(const A0DataType* p_a0_grid, + const B0DataType* p_b0_grid, + std::array p_d0s_grid, + const B1DataType* p_b1_grid, + std::array p_d1s_grid, + E1DataType* p_e1_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, // = ORaw + index_t Batch, + index_t StrideA0, + index_t StrideB0, + std::array StrideD0s, + index_t StrideB1, + std::array StrideD1s, + index_t StrideE1, + index_t BatchStrideA0, + index_t BatchStrideB0, + std::array BatchStrideD0s, + index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1, + A0ElementwiseOperation a0_element_op, + B0ElementwiseOperation b0_element_op, + CDE0ElementwiseOperation cde0_element_op, + B1ElementwiseOperation b1_element_op, + CDE1ElementwiseOperation cde1_element_op) + : p_a0_grid_{p_a0_grid}, + p_b0_grid_{p_b0_grid}, + p_d0s_grid_{}, + p_b1_grid_{p_b1_grid}, + p_d1s_grid_{}, + p_e1_grid_{p_e1_grid}, + a0_grid_desc_m_k_{DeviceOp::MakeA0GridDescriptor_M_K(MRaw, KRaw, StrideA0)}, + b0_grid_desc_n_k_{DeviceOp::MakeB0GridDescriptor_N_K(KRaw, NRaw, StrideB0)}, + d0s_grid_desc_m_n_{}, + b1_grid_desc_n_k_{DeviceOp::MakeB1GridDescriptor_N_K(NRaw, Gemm1NRaw, StrideB1)}, + d1s_grid_desc_m_n_{}, + e1_grid_desc_m_n_{ + DeviceOp::MakeE1GridDescriptor_M_N(MRaw, Gemm1NRaw, StrideE1)}, + a0_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(a0_grid_desc_m_k_)}, + b0_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(b0_grid_desc_n_k_)}, + d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{}, + b1_grid_desc_bk0_n_bk1_{ + 
GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(b1_grid_desc_n_k_)}, + d1s_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e1_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_e1tile_map_{GridwiseGemm::MakeDefaultBlock2E1TileMap(e1_grid_desc_m_n_)}, + a0_element_op_{a0_element_op}, + b0_element_op_{b0_element_op}, + cde0_element_op_{cde0_element_op}, + b1_element_op_{b1_element_op}, + cde1_element_op_{cde1_element_op}, + batch_count_(Batch), + compute_base_ptr_of_batch_{BatchStrideA0, + BatchStrideB0, + BatchStrideD0s, + BatchStrideB1, + BatchStrideD1s, + BatchStrideE1} + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "a0_grid_desc_m_k_{" << a0_grid_desc_m_k_.GetLength(I0) << ", " + << a0_grid_desc_m_k_.GetLength(I1) << "}" << std::endl; + std::cout << "b0_grid_desc_n_k_{" << b0_grid_desc_n_k_.GetLength(I0) << ", " + << b0_grid_desc_n_k_.GetLength(I1) << "}" << std::endl; + std::cout << "d0s_grid_desc_m_n_[I0]{" << d0s_grid_desc_m_n_[I0].GetLength(I0) + << ", " << d0s_grid_desc_m_n_[I0].GetLength(I1) << "}" << std::endl; + std::cout << "b1_grid_desc_n_k_{" << b1_grid_desc_n_k_.GetLength(I0) << ", " + << b1_grid_desc_n_k_.GetLength(I1) << "}" << std::endl; + std::cout << "d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{" + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I0) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I1) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I2) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I3) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I4) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I5) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I6) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I7) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I8) << ", " + << d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_[I0].GetLength(I9) << "}" + << std::endl; + std::cout << "e1_grid_desc_m_n_{" << e1_grid_desc_m_n_.GetLength(I0) << ", " + << e1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + using D0Layout = remove_cvref_t>; + using D0DataType = remove_cvref_t>; + + // D0 pointer + p_d0s_grid_(i) = static_cast(p_d0s_grid[i]); + + // D0 desc + d0s_grid_desc_m_n_(i) = + DeviceOp::MakeD0GridDescriptor_M_N(MRaw, NRaw, StrideD0s[i]); + }); + + static_for<0, NumD1Tensor, 1>{}([&](auto i) { + using D1Layout = remove_cvref_t>; + using D1DataType = remove_cvref_t>; + + // D1 pointer + p_d1s_grid_(i) = static_cast(p_d1s_grid[i]); + + // D1 desc + d1s_grid_desc_m_n_(i) = + DeviceOp::MakeE1GridDescriptor_M_N(MRaw, Gemm1NRaw, StrideD1s[i]); + }); + + if(GridwiseGemm::CheckValidity(a0_grid_desc_m_k_, + b0_grid_desc_n_k_, + b1_grid_desc_n_k_, + e1_grid_desc_m_n_, + block_2_e1tile_map_)) + { + e1_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e1_grid_desc_m_n_); + + d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ = + GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5( + d0s_grid_desc_m_n_); + + d1s_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + d1s_grid_desc_m_n_); + } + } + + // private: + // pointers + const A0DataType* p_a0_grid_; + const B0DataType* p_b0_grid_; + typename GridwiseGemm::D0sGridPointer p_d0s_grid_; + const B1DataType* 
p_b1_grid_; + typename GridwiseGemm::D1sGridPointer p_d1s_grid_; + E1DataType* p_e1_grid_; + + // tensor descriptors for problem definiton + A0GridDesc_M_K a0_grid_desc_m_k_; + B0GridDesc_N_K b0_grid_desc_n_k_; + D0sGridDesc_M_N d0s_grid_desc_m_n_; + B1GridDesc_N_K b1_grid_desc_n_k_; + D1sGridDesc_M_N d1s_grid_desc_m_n_; + E1GridDesc_M_N e1_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1_; + B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1_; + typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 + d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; + B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; + typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + d1s_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e1_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e1-tile map + typename GridwiseGemm::DefaultBlock2E1TileMap block_2_e1tile_map_; + + // element-wise op + A0ElementwiseOperation a0_element_op_; + B0ElementwiseOperation b0_element_op_; + CDE0ElementwiseOperation cde0_element_op_; + B1ElementwiseOperation b1_element_op_; + CDE1ElementwiseOperation cde1_element_op_; + + // batch + index_t batch_count_; + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_, + arg.b0_grid_desc_n_k_, + arg.b1_grid_desc_n_k_, + arg.e1_grid_desc_m_n_, + arg.block_2_e1tile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_e1tile_map_.CalculateGridSize(arg.e1_grid_desc_m_n_) * arg.batch_count_; + + // Gemm0_K + const auto K = arg.a0_grid_desc_m_k_.GetLength(I1); + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = kernel_batched_gemm_gemm_xdl_cshuffle_v1< + GridwiseGemm, + A0DataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::D0sGridPointer, + typename GridwiseGemm::D1sGridPointer, + E1DataType, + A0ElementwiseOperation, + B0ElementwiseOperation, + CDE0ElementwiseOperation, + B1ElementwiseOperation, + CDE1ElementwiseOperation, + DeviceOp::A0GridDesc_AK0_M_AK1, + DeviceOp::B0GridDesc_BK0_N_BK1, + typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, + DeviceOp::B1GridDesc_BK0_N_BK1, + typename GridwiseGemm::D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2E1TileMap, + ComputeBasePtrOfStridedBatch, + has_main_k_block_loop_>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a0_grid_, + arg.p_b0_grid_, + arg.p_d0s_grid_, + arg.p_b1_grid_, + arg.p_d1s_grid_, + arg.p_e1_grid_, + arg.a0_element_op_, + arg.b0_element_op_, + arg.cde0_element_op_, + arg.b1_element_op_, + arg.cde1_element_op_, + arg.a0_grid_desc_ak0_m_ak1_, + arg.b0_grid_desc_bk0_n_bk1_, + arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.d1s_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e1_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_e1tile_map_, + arg.batch_count_, + arg.compute_base_ptr_of_batch_); + }; + + // Gemm1_K is split into Gemm1_K0/K1 where K1 
is known at compile time, so we only need + // to concern Gemm0's loop + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + // check if DsLayout is supported + template + static bool CheckDLayout() + { + static bool valid = true; + // iterate over DLayout tuple + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + // if RefLayout and DLayout are same, keep valid true, otherwise false + valid = valid && is_same_v; + }); + return valid; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + // Check supported layouts + // A0 - Row + // B0 - Col + // D0s - Rows + // B1 - Row or Col + // D1s - Rows + // E1 - Row + if(!(is_same_v && + is_same_v && + CheckDLayout() && + (is_same_v || + is_same_v)&&CheckDLayout() && + is_same_v)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a0_grid_desc_m_k_, + arg.b0_grid_desc_n_k_, + arg.b1_grid_desc_n_k_, + arg.e1_grid_desc_m_n_, + arg.block_2_e1tile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const A0DataType* p_a0, + const B0DataType* p_b0, + std::array p_d0s, + const B1DataType* p_b1, + std::array p_d1s, + E1DataType* p_e1, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, + index_t Batch, + index_t StrideA0, + index_t StrideB0, + std::array StrideD0s, + index_t StrideB1, + std::array StrideD1s, + index_t StrideE1, + index_t BatchStrideA0, + index_t BatchStrideB0, + std::array BatchStrideD0s, + index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1, + A0ElementwiseOperation a0_element_op, + B0ElementwiseOperation b0_element_op, + CDE0ElementwiseOperation cde0_element_op, + B1ElementwiseOperation b1_element_op, + CDE1ElementwiseOperation cde1_element_op) + { + return Argument{p_a0, p_b0, + p_d0s, p_b1, + p_d1s, p_e1, + MRaw, NRaw, + KRaw, Gemm1NRaw, + Batch, StrideA0, + StrideB0, StrideD0s, + StrideB1, StrideD1s, + StrideE1, BatchStrideA0, + BatchStrideB0, BatchStrideD0s, + BatchStrideB1, BatchStrideD1s, + BatchStrideE1, a0_element_op, + b0_element_op, cde0_element_op, + b1_element_op, cde1_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a0, + const void* p_b0, + std::array p_d0s, + const void* p_b1, + std::array p_d1s, + void* p_e1, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, + index_t Batch, + index_t StrideA0, + index_t StrideB0, + std::array StrideD0s, + index_t StrideB1, + std::array StrideD1s, + index_t StrideE1, + index_t BatchStrideA0, + index_t BatchStrideB0, + std::array BatchStrideD0s, + index_t BatchStrideB1, + std::array BatchStrideD1s, + index_t BatchStrideE1, + A0ElementwiseOperation a0_element_op, + B0ElementwiseOperation b0_element_op, + CDE0ElementwiseOperation cde0_element_op, + B1ElementwiseOperation b1_element_op, + CDE1ElementwiseOperation cde1_element_op) override + { + return std::make_unique(static_cast(p_a0), 
+ static_cast(p_b0), + p_d0s, + static_cast(p_b1), + p_d1s, + static_cast(p_e1), + MRaw, + NRaw, + KRaw, + Gemm1NRaw, + Batch, + StrideA0, + StrideB0, + StrideD0s, + StrideB1, + StrideD1s, + StrideE1, + BatchStrideA0, + BatchStrideB0, + BatchStrideD0s, + BatchStrideB1, + BatchStrideD1s, + BatchStrideE1, + a0_element_op, + b0_element_op, + cde0_element_op, + b1_element_op, + cde1_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << Gemm0MPerBlock << ", " + << Gemm0NPerBlock << ", " + << Gemm0KPerBlock << ", " + << A0K1 << ", " + << B0K1 << ", " + << B1K1 << ", " + << Gemm0MPerXdl << ", " + << Gemm0NPerXdl << ", " + << Gemm0MXdlPerWave << ", " + << Gemm0NXdlPerWave << ", " + << Gemm1NXdlPerWave << "> "; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp new file mode 100755 index 000000000000..6624570b277e --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -0,0 +1,1051 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_xdl_cshuffle_v3_multi_d(BatchedGemmArg karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx = blockIdx.z / karg.Batch; + + const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + + // populate pointer, desc for Ds + static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + // D pointer + karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i]; + }); + + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid + c_batch_offset, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds(BatchedGemmArg karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx = blockIdx.z / karg.Batch; + + const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + + // populate pointer, desc for Ds + static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + // D pointer + karg.p_ds_grid(i) = karg.p_ds_grid(i) + ds_batch_offset[i]; + }); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid + c_batch_offset, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 + : public DeviceBatchedGemmV2MultiD +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + using CDEShuffleBlockTransferScalarPerVectors_ = CDEShuffleBlockTransferScalarPerVectors; + using 
CDataType_ = CDataType; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch() = default; + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + std::array BatchStrideDs, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideC_(BatchStrideC) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return static_cast(BatchStrideA_) * g_idx; + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return static_cast(BatchStrideB_) * g_idx; + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset_; + + static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { + ds_offset_[i] = static_cast(BatchStrideDs_[i]) * g_idx; + }); + + return ds_offset_; + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return static_cast(BatchStrideC_) * g_idx; + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + std::array BatchStrideDs_; + index_t BatchStrideC_; + }; + + struct Argument : public GridwiseGemm::Argument + { + index_t Batch; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; + + Argument() = default; + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + CDataType* p_e_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideE_, + index_t BatchStrideA_, + index_t BatchStrideB_, + const std::array& BatchStrideDs_, + index_t BatchStrideE_, + index_t Batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_, + index_t KBatch_) + : GridwiseGemm::Argument{p_a_grid_, + p_b_grid_, + p_ds_grid_, + p_e_grid_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideDs_, + StrideE_, + KBatch_, + a_element_op_, + b_element_op_, + c_element_op_}, + Batch{Batch_}, + compute_ptr_offset_of_batch{ + BatchStrideA_, BatchStrideB_, BatchStrideDs_, BatchStrideE_} + { + } + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + 
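            // Rough dispatch summary (from the code below): after validating the argument, the
            // launch grid is derived from the M/N tiling with Batch * KBatch folded into the
            // z-dimension, and the kernel instantiation is selected from the block-GEMM pipeline
            // version, the split-K memory operation (AtomicAdd when KBatch > 1, otherwise Set)
            // and the K-loop tail number. When stream_config.flush_cache is set, the input
            // buffers are rotated and the icache flushed between timed launches so cache warm-up
            // does not skew the measurement.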
if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = + GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch * arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + + std::array DsSize; + + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) * arg.Batch; + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) * arg.Batch; + + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + ck::utility::RotatingMemWrapperMultiD rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString( + hipMemsetAsync(arg_.p_c_grid, + 0, + arg.Batch * arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + const auto clear_workspace = [&]() { + if(arg.KBatch > 1) + hipGetErrorString( + hipMemsetAsync(arg.p_c_grid, + 0, + arg.Batch * arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = launch_and_time_kernel_with_preprocess(stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 
2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + 
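                    // KBatch == 1: there is no split-K accumulation, so output tiles are written
                    // with InMemoryDataOperationEnum::Set rather than AtomicAdd; the tail-number
                    // dispatch below otherwise mirrors the split-K branch above.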
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + Argument, + 
true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + Argument, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch = 1) + { + return Argument{static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_e), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + 
BatchStrideE, + Batch, + a_element_op, + b_element_op, + c_element_op, + KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t Batch, + index_t StrideA, + index_t StrideB, + const std::array& StrideDs, + index_t StrideE, + index_t BatchStrideA, + index_t BatchStrideB, + const std::array& BatchStrideDs, + index_t BatchStrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_e), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + BatchStrideA, + BatchStrideB, + BatchStrideDs, + BatchStrideE, + Batch, + a_element_op, + b_element_op, + c_element_op, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceBatchedGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + ReducePtrsGlobal p_reduces_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const ReduceInElementwiseOperations reduce_in_element_ops, + const ReduceAccElementwiseOperations reduce_out_element_ops, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, + const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, + const Block2CTileMap block_2_ctile_map) +{ 
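    // Each block maps to one (batch, tile) pair: the grid holds batch_count groups of
    // num_blocks_per_batch blocks, the batch index g_idx is recovered from the 1-D block id,
    // and the per-batch base-pointer offsets from compute_base_ptr_of_batch_ are broadcast
    // with __builtin_amdgcn_readfirstlane so they remain wavefront-uniform scalar values.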
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); + + static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) { + const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In))); + p_reduces_grid(In) = p_reduces_grid(In) + d_batch_offset; + }); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_reduces_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_reduces_grid; + ignore = batch_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = reduce_grid_desc_mblock_mperblock; + ignore = compute_base_ptr_of_batch_; + ignore = block_2_ctile_map; +#endif +} + +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. 
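The MakeAGridDescriptor_AK0_M_AK1 / MakeBGridDescriptor_BK0_N_BK1 helpers that follow all start from the same arithmetic: round each raw extent up to a multiple of its per-block tile, then unmerge the (possibly padded) K extent into K0 x K1. A minimal standalone sketch of that arithmetic, with names mirroring the surrounding parameters (MRaw, KRaw, MPerBlock, KPerBlock, AK1) and everything else purely illustrative:

#include <cassert>
#include <cstdint>

struct PaddedGemmDims
{
    std::int64_t M, K, MPad, KPad, AK0;
};

// Pad-to-tile-multiple arithmetic for the K-padded specializations
// (KPadding / MKPadding / NKPadding / MNKPadding).
inline PaddedGemmDims pad_to_block_multiples(std::int64_t MRaw, std::int64_t KRaw,
                                             std::int64_t MPerBlock, std::int64_t KPerBlock,
                                             std::int64_t AK1)
{
    // Round each raw length up to the nearest multiple of its per-block tile size.
    const std::int64_t M = (MRaw + MPerBlock - 1) / MPerBlock * MPerBlock;
    const std::int64_t K = (KRaw + KPerBlock - 1) / KPerBlock * KPerBlock;
    // The padded K extent is then unmerged into AK0 x AK1 for the (AK0, M, AK1) layout.
    assert(K % AK1 == 0);
    return {M, K, M - MRaw, K - KRaw, K / AK1};
}

For the specializations that do not pad K, the same split is taken directly on the raw extent (AK0 = KRaw / AK1), which is why those branches assert KRaw % AK1 == 0 instead.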
+template +struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()> +{ + using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return 
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == 
GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + // assume D is packed tensor + static auto MakeReduceGridDescriptor_M(index_t MRaw) + { + const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor(d_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return d_grid_desc_mraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1)); + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideD) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideC_(BatchStrideC), + BatchStrideD_(BatchStrideD) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + template + __host__ __device__ constexpr long_index_t GetDBasePtr(index_t g_idx, + Number reduction_idx) const + { + // TODO - Support sequence of StrideD in MakeArgument() + (void)reduction_idx; + return g_idx * static_cast(BatchStrideD_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; + index_t BatchStrideD_; + }; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + ReduceAccDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + 
CElementwiseOperation, + ReduceOperations, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + InMemoryDataOperationEnum::Set, + ReduceGlobalMemoryDataOperation, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + ReduceGridDesc_M, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + ReducePtrsGlobal p_reduces_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ReduceInElementwiseOperations reduce_in_element_ops, + ReduceAccElementwiseOperations reduce_out_element_ops, + index_t Batch) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_reduces_grid_{p_reduces_grid}, + Batch_(Batch), + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + reduce_grid_desc_mblock_mperblock_{}, + compute_base_ptr_of_batch_{ + type_convert(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()), + type_convert(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()), + type_convert(c_grid_desc_m_n_.GetElementSpaceSize()), + type_convert(reduce_grid_desc_m_.GetElementSpaceSize())}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + reduce_in_element_ops_{reduce_in_element_ops}, + reduce_out_element_ops_{reduce_out_element_ops} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + reduce_grid_desc_mblock_mperblock_ = + GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + ReducePtrsGlobal p_reduces_grid_; + index_t Batch_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 
b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + ReduceGridDesc_M reduce_grid_desc_m_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock + reduce_grid_desc_mblock_mperblock_; + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + ReduceInElementwiseOperations reduce_in_element_ops_; + ReduceAccElementwiseOperations reduce_out_element_ops_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + { + std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl; + + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.reduce_grid_desc_m_{ " + << arg.reduce_grid_desc_m_.GetLength(I0) << "}" << std::endl; + } + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_; + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float elapsed_time = 0.0f; + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + ComputeBasePtrOfStridedBatch, + typename GridwiseGemm::DefaultBlock2CTileMap, + true>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_reduces_grid_, + arg.Batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.compute_base_ptr_of_batch_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_batched_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + ComputeBasePtrOfStridedBatch, + typename GridwiseGemm::DefaultBlock2CTileMap, + false>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_reduces_grid_, + arg.Batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.compute_base_ptr_of_batch_, + arg.block_2_ctile_map_); + } + + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + auto casted_p_arg = dynamic_cast(p_arg); + if(casted_p_arg == nullptr) + { + return false; + } + else + { + return IsSupportedArgument(*casted_p_arg); + } + } + + static constexpr int NumReduce = ReduceOperations::Size(); + 
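    // Note on the argument factories below: the DeviceGemmReduce base interface passes the
    // element-wise operations and reduction outputs as type-erased void pointers, so
    // MakeArgument()/MakeArgumentPointer() rebuild the typed tuples with generate_tuple
    // (one static_cast per slot). The bias/Ds parameters are accepted only to satisfy that
    // interface and are deliberately ignored here.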
static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + index_t Batch) + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + Batch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + index_t Batch = 1) override + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + Batch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // 
polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmReduce_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp new file mode 100755 index 000000000000..1026118381bc --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp @@ -0,0 +1,1730 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_softmax_gemm_wmma_cshuffle(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + // clang-format off +// *************************************************** +// Make Tensor Descriptors + constexpr index_t array_size = 4; + std::array a_gs_ms_ks_lengths{G0, G1, M, K}; + std::array a_gs_ms_ks_strides = + input_permute + ? std::array{M * G1 * K, K, G1 * K, 1} // A layout [G0, M, G1, K] + : std::array{G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute + ? std::array{N * G1 * K, K, G1 * K, 1} // B0 layout [G0, N, G1, K] + : std::array{G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::array b1_gs_os_ns_lengths{G0, G1, O, N}; + std::array b1_gs_os_ns_strides = + input_permute + ? std::array{N * G1 * O, O, 1, G1 * O} // B1 layout [G0, N, G1, O] + : std::array{G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::array c_gs_ms_os_lengths{G0, G1, M, O}; + std::array c_gs_ms_os_strides = + output_permute + ? 
std::array{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O] + : std::array{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + // fail to reuse DeviceOp::MakeArgument() because of the __device__ function required. + + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseOp::template Run(p_a_grid + a_batch_offset, + p_b0_grid + b0_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b0_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = M; + ignore = N; + ignore = K; + ignore = O; + ignore = G0; + ignore = G1; + ignore = alpha; + ignore = input_permute; + ignore = output_permute; +#endif // end of if (defined(__gfx11__)) +} + +// Self-Attention +template 
+__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_wmma_self_attention_forward(const QKVDataType* __restrict__ p_qkv_grid, + ODataType* __restrict__ p_out_grid, + index_t batch_size, + index_t sequence_length, + index_t head_count, + index_t head_size, + float alpha) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + // clang-format off +// *************************************************** +// Make Tensor Descriptors +// o Self-attention(packed QKV): [batchSize, sequenceLength, headCount, 3, headSize] + constexpr index_t array_size = 4; + std::array qk_gs_ms_ks_lengths{batch_size, head_count, sequence_length, head_size}; + std::array qk_gs_ms_ks_strides{sequence_length * head_count * 3 * head_size, 3 * head_size, head_count * 3 * head_size, 1}; + + std::array v_gs_os_ns_lengths{batch_size, head_count, head_size, sequence_length}; + std::array v_gs_os_ns_strides{sequence_length * head_count * 3 * head_size, 3 * head_size, 1, head_count * 3 * head_size}; + + std::array c_gs_ms_os_lengths{batch_size, head_count, sequence_length, head_size}; + std::array c_gs_ms_os_strides{sequence_length * head_count * head_size, head_size, head_count * head_size, 1}; + + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(v_gs_os_ns_lengths, v_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(qk_gs_ms_ks_lengths, qk_gs_ms_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(v_gs_os_ns_lengths, v_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = 
__builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + const index_t qkv_gap = __builtin_amdgcn_readfirstlane(head_size); +#ifdef CK_SELF_ATTN_DEBUG + if(get_thread_global_1d_id() == 0) + { + printf("batch_size: %d\n", batch_size); + printf("sequence_length: %d\n", sequence_length); + printf("head_count: %d\n", head_count); + printf("head_size: %d\n", head_size); + printf("qkv_gap: %d\n", qkv_gap); + printf("get_grid_size(): %d\n", get_grid_size()); + printf("batch_count: %d\n", batch_count); + printf("blockid: %d\n", get_block_1d_id()); + printf("num_blocks_per_batch: %d\n", num_blocks_per_batch); + printf("g_idx: %d\n", g_idx); + printf("a_batch_offset: %ld\n", a_batch_offset); + printf("b0_batch_offset: %ld\n", b0_batch_offset); + printf("b1_batch_offset: %ld\n", b1_batch_offset); + } +#endif + GridwiseOp::template Run(p_qkv_grid + 0 * qkv_gap + a_batch_offset, + p_qkv_grid + 1 * qkv_gap + b0_batch_offset, + p_qkv_grid + 2 * qkv_gap + b1_batch_offset, + p_out_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_qkv_grid; + ignore = p_out_grid; + ignore = batch_size; + ignore = sequence_length; + ignore = head_count; + ignore = head_size; + ignore = alpha; +#endif // end of if (defined(__gfx11__)) +} +// Cross-Attention +// Self-Attention +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_wmma_cross_attention_forward(const QDataType* __restrict__ p_q_grid, + const KVDataType* __restrict__ p_kv_grid, + ODataType* __restrict__ p_out_grid, + index_t batch_size, + index_t q_sequence_length, + index_t kv_sequence_length, + index_t head_count, + index_t head_size, + float alpha) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + // clang-format off +// *************************************************** +// Make Tensor Descriptors +// o Self-attention(packed QKV): [batchSize, sequenceLength, headCount, 3, headSize] + constexpr index_t array_size = 4; + std::array q_gs_ms_ks_lengths{batch_size, head_count, q_sequence_length, head_size}; + std::array q_gs_ms_ks_strides{q_sequence_length * head_count * head_size, head_size, head_count * head_size, 1}; + + std::array k_gs_ms_ks_lengths{batch_size, head_count, kv_sequence_length, head_size}; + std::array k_gs_ms_ks_strides{kv_sequence_length * head_count * 2 * head_size, 2 * head_size, head_count * 2 * head_size, 1}; + + std::array v_gs_os_ns_lengths{batch_size, head_count, head_size, kv_sequence_length}; + std::array v_gs_os_ns_strides{kv_sequence_length * head_count * 2 * head_size, 2 * head_size, 1, head_count * 2 * head_size}; + + std::array c_gs_ms_os_lengths{batch_size, head_count, q_sequence_length, head_size}; + std::array c_gs_ms_os_strides{q_sequence_length * head_count * head_size, head_size, head_count * head_size, 1}; + + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + 
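// ---------------------------------------------------------------------------
// Illustrative aside (not part of the original patch): the self-/cross-attention
// kernels above never unpack Q, K and V into separate buffers. For the packed
// QKV layout [batch, seq, head, 3, head_size] the three operands are contiguous
// slices of the innermost [3, head_size] block, so the kernel only shifts the
// base pointer by qkv_gap = head_size (0*gap for Q, 1*gap for K, 2*gap for V)
// and folds the factor 3*head_size into the row strides; the packed KV layout
// of the cross-attention path works the same way with a factor of 2 and
// kv_gap = head_size. The alpha handed to AccElementwiseOperation is the
// softmax scale applied to the first GEMM's accumulator (commonly
// 1/sqrt(head_size) in attention). The host-side sketch below reproduces the
// offset arithmetic with plain integer math; packed_qkv_offset is an
// illustrative name, not CK API.
#include <cassert>
#include <cstdint>

// Linear offset of element (b, s, h, d) of operand `which` (0=Q, 1=K, 2=V)
// inside a packed [batch, seq, head, 3, head_size] QKV tensor.
static std::int64_t packed_qkv_offset(std::int64_t b, std::int64_t s, std::int64_t h,
                                      std::int64_t which, std::int64_t d,
                                      std::int64_t seq, std::int64_t heads,
                                      std::int64_t head_size)
{
    return (((b * seq + s) * heads + h) * 3 + which) * head_size + d;
}

int main()
{
    const std::int64_t seq = 128, heads = 8, head_size = 64;
    // Strides used by the kernel's (batch, head, seq, d) grid descriptor view:
    const std::int64_t stride_b = seq * heads * 3 * head_size; // batch stride
    const std::int64_t stride_s = heads * 3 * head_size;       // sequence stride
    const std::int64_t stride_h = 3 * head_size;               // head stride
    for (std::int64_t which = 0; which < 3; ++which)
    {
        // Same element reached via "base + which * qkv_gap" plus the strided walk.
        const std::int64_t via_gap =
            which * head_size /* qkv_gap */ + 2 * stride_b + 5 * stride_s + 3 * stride_h + 7;
        assert(via_gap == packed_qkv_offset(2, 5, 3, which, 7, seq, heads, head_size));
    }
    return 0;
}
// ---------------------------------------------------------------------------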
const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(k_gs_ms_ks_lengths, k_gs_ms_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(v_gs_os_ns_lengths, v_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(q_gs_ms_ks_lengths, q_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(k_gs_ms_ks_lengths, k_gs_ms_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(v_gs_os_ns_lengths, v_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB0BasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + const index_t kv_gap = __builtin_amdgcn_readfirstlane(head_size); +#ifdef CK_SELF_ATTN_DEBUG + if(get_thread_global_1d_id() == 0) + { + printf("batch_size: %d\n", batch_size); + printf("q_sequence_length: %d\n", q_sequence_length); + printf("k_sequence_length: %d\n", kv_sequence_length); + printf("head_count: %d\n", head_count); + printf("head_size: %d\n", head_size); + printf("kv_gap: %d\n", kv_gap); + printf("get_grid_size(): %d\n", get_grid_size()); + printf("batch_count: %d\n", batch_count); + printf("blockid: %d\n", get_block_1d_id()); + printf("num_blocks_per_batch: %d\n", num_blocks_per_batch); + printf("g_idx: %d\n", g_idx); + printf("a_batch_offset: %ld\n", a_batch_offset); + printf("b0_batch_offset: %ld\n", b0_batch_offset); + printf("b1_batch_offset: %ld\n", b1_batch_offset); + } +#endif + GridwiseOp::template Run(p_q_grid + a_batch_offset, + p_kv_grid + 0 * kv_gap + b0_batch_offset, + p_kv_grid + 1 * kv_gap + b1_batch_offset, + p_out_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + 
b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_q_grid; + ignore = p_kv_grid; + ignore = p_out_grid; + ignore = batch_size; + ignore = q_sequence_length; + ignore = kv_sequence_length; + ignore = head_count; + ignore = head_size; + ignore = alpha; +#endif // end of if (defined(__gfx11__)) +} +// Computes C = A * B0 * B1 +// MN = MK * KL * LN +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle + : public DeviceBatchedGemmSoftmaxGemmPermute +{ + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + // TODO ANT: implement bias combination + static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); + + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimL; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimN; + static constexpr index_t NumDimGemm1K = NumDimL; + + using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + static constexpr auto WmmaK = 16; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + + static constexpr auto AEnableLds_auto = LWaves == 1 ? false : true; + static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true; + static constexpr auto B1EnableLds_auto = MWaves == 1 ? 
false : true; + + static constexpr auto AEnableLds_manu = false; + static constexpr auto B0EnableLds_manu = true; + static constexpr auto B1EnableLds_manu = true; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1); + static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1); + + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence, + Sequence, + GemmSpec, + ASpec, + B0Spec, + B1Spec, + CSpec>; + + __host__ __device__ static auto MakeAGridDescriptor( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + if constexpr(AEnableLds) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB0GridDescriptor( + const std::array& b0_gs_ls_ks_lengths_vec, + const std::array& b0_gs_ls_ks_strides_vec) + { + if constexpr(B0EnableLds) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB1GridDescriptor( + const std::array& b1_gs_ns_ls_lengths_vec, + const std::array& b1_gs_ns_ls_strides_vec) + { + if constexpr(B1EnableLds) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + using AGridDesc = decltype(MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); + using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + + __host__ __device__ constexpr static auto make_MaskOutPredicate() + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + return MaskDisabledPredicate{}; + } + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + return MaskOutUpperTrianglePredicate{}; + } + } + using C0MatrixMask = C0MatrixMask_impl; + + struct ComputeBasePtrOfStridedBatch + { + __host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const 
B0GridDesc_G_L_K& b0_grid_desc_g_l_k, + const B1GridDesc_G_N_L& b1_grid_desc_g_n_l, + const CGridDesc_G_M_N& c_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k), + b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l), + c_grid_desc_g_m_n_(c_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + private: + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + }; + + // GridwiseOp + using GridwiseOp = GridwiseBatchedGemmSoftmaxGemm_Wmma< + // DataType Family + ADataType, + B0DataType, + Acc0DataType, + B1DataType, + Acc1DataType, + CShuffleDataType, + CDataType, + // ElementwiseOp Family + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // InMemory Data Descriptor + AGridDesc, + B0GridDesc, + B1GridDesc, + CGridDesc_M_N, + // Tiling Family + MPerBlock, + LPerBlock, + KPerBlock, + AK1, + BK1, + NPerBlock, + LTilePerBlock, + L1, + MPerWmma, + LPerWmma, + NPerWmma, + MRepeat, + LRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + true, + AEnableLds, + ABlockLdsAddExtraM, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0BlockTransferSrcAccessOrder, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_K1, + true, + B0EnableLds, + B0BlockLdsAddExtraL, + B1BlockTransferThreadClusterLengths_L0_N_L1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_L1, + false, + B1EnableLds, + B1BlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, + NumPrefetch, + LoopSched, + PipelineVer>; + + struct RawArg : public BaseArgument + { + RawArg(const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + M_{M}, + N_{N}, + K_{K}, + O_{O}, + G0_{G0}, + G1_{G1}, + alpha_{alpha}, + input_permute_{input_permute}, + output_permute_{output_permute} + { + } + // Pointers + 
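// ---------------------------------------------------------------------------
// Illustrative aside (not part of the original patch): ComputeBasePtrOfStridedBatch
// above only evaluates each G-major descriptor at (g, 0, 0), i.e. it returns
// g * batch_stride per operand, while the kernels recover g from the flat block
// id because the grid is launched as batch_count * blocks_per_batch blocks
// (blocks_per_batch = ceil(M/MPerBlock) * ceil(O/NPerBlock) in the Invoker
// further below). A minimal host-side model of that bookkeeping follows; the
// names are placeholders, not CK API.
#include <cassert>
#include <cstdint>

struct BatchBasePtrModel
{
    std::int64_t batch_stride;
    std::int64_t operator()(int g) const { return g * batch_stride; }
};

int main()
{
    const int G0 = 2, G1 = 3;                    // e.g. batch count and head count
    const int M0 = 4, N0 = 2;                    // ceil(M/MPerBlock), ceil(O/NPerBlock)
    const int batch_count = G0 * G1;
    const int grid_size = batch_count * M0 * N0; // as computed in Invoker::Run below
    const int num_blocks_per_batch = grid_size / batch_count;

    const BatchBasePtrModel a_base{/*batch_stride=*/4096};
    for (int block_id = 0; block_id < grid_size; ++block_id)
    {
        const int g_idx = block_id / num_blocks_per_batch; // g_idx in the kernels above
        assert(g_idx < batch_count);
        assert(a_base(g_idx) == std::int64_t(g_idx) * 4096);
    }
    return 0;
}
// ---------------------------------------------------------------------------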
const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Raw Problem Size + index_t M_; + index_t N_; + index_t K_; + index_t O_; + index_t G0_; + index_t G1_; + float alpha_; + bool input_permute_; + bool output_permute_; + }; + + static auto MakeArgument(const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + { + return RawArg{ + p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute}; + } + + static bool IsSupportedArgument(const RawArg& arg) + { + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + constexpr index_t array_size = 4; + ck::index_t G0 = arg.G0_; + ck::index_t G1 = arg.G1_; + ck::index_t M = arg.M_; + ck::index_t N = arg.N_; + ck::index_t K = arg.K_; + ck::index_t O = arg.O_; + bool input_permute = arg.input_permute_; + bool output_permute = arg.output_permute_; + + std::array a_gs_ms_ks_lengths{G0, G1, M, K}; + std::array a_gs_ms_ks_strides = + input_permute ? std::array{M * G1 * K, K, G1 * K, 1} + // A layout [G0, M, G1, K] + : std::array{ + G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute ? std::array{N * G1 * K, K, G1 * K, 1} + // B0 layout [G0, N, G1, K] + : std::array{ + G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::array b1_gs_os_ns_lengths{G0, G1, O, N}; + std::array b1_gs_os_ns_strides = + input_permute ? std::array{N * G1 * O, O, 1, G1 * O} + // B1 layout [G0, N, G1, O] + : std::array{ + G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::array c_gs_ms_os_lengths{G0, G1, M, O}; + std::array c_gs_ms_os_strides = + output_permute ? 
std::array{M * G1 * O, O, G1 * O, 1} + // C layout [G0, M, G1, O] + : std::array{ + G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_grid_desc = + DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + + if(!GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded + + if(!(c_g == batch_count)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = M; + const auto LzRaw = N; + const auto KzRaw = K; + const auto NzRaw = O; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + std::array a_mz_kz_strides_{ + a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}; + std::array b0_lz_kz_strides_{ + b0_gs_ns_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]}; + std::array b1_nz_lz_strides_{ + b1_gs_os_ns_strides[NumDimG + NumDimN - 1], + b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]}; + std::array c_mz_nz_strides_{ + c_gs_ms_os_strides[NumDimG + NumDimM - 1], + c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]}; + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0]; + const auto c_stride_lowest = c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + struct SelfAttnArg : public BaseArgument + { + SelfAttnArg(const ADataType* p_qkv_grid, + CDataType* p_out_grid, + index_t batch_size, + index_t sequence_length, + index_t head_count, + index_t head_size, + float alpha) + : p_qkv_grid_{p_qkv_grid}, + p_out_grid_{p_out_grid}, + batch_size_{batch_size}, + sequence_length_{sequence_length}, + head_count_{head_count}, + head_size_{head_size}, + alpha_{alpha} + { + } + // Pointers + const ADataType* p_qkv_grid_; + CDataType* p_out_grid_; + + // Raw Problem Size + index_t batch_size_; + index_t sequence_length_; + index_t head_count_; + index_t head_size_; + float alpha_; + }; + + static auto MakeSelfAttnArgument(const ADataType* p_qkv_grid, + CDataType* p_out_grid, + index_t batch_size, + index_t sequence_length, + index_t head_count, + index_t head_size, + float alpha) + { + return SelfAttnArg{ + p_qkv_grid, p_out_grid, batch_size, sequence_length, head_count, head_size, alpha}; + } + + struct CrossAttnArg : public BaseArgument + { + CrossAttnArg(const ADataType* p_q_grid, + const B0DataType* p_kv_grid, + CDataType* p_out_grid, + index_t batch_size, + index_t q_sequence_length, + index_t kv_sequence_length, + index_t head_count, + index_t head_size, + float alpha) + : p_q_grid_{p_q_grid}, + p_kv_grid_{p_kv_grid}, + p_out_grid_{p_out_grid}, + batch_size_{batch_size}, + q_sequence_length_{q_sequence_length}, + kv_sequence_length_{kv_sequence_length}, + head_count_{head_count}, + head_size_{head_size}, + alpha_{alpha} + { + } + // Pointers + const ADataType* p_q_grid_; + const B0DataType* p_kv_grid_; + CDataType* p_out_grid_; + + // Raw Problem Size + index_t batch_size_; + index_t q_sequence_length_; + index_t kv_sequence_length_; + index_t head_count_; + index_t head_size_; + float alpha_; + }; + + static auto MakeCrossAttnArgument(const ADataType* p_q_grid, + const B0DataType* p_kv_grid, + CDataType* p_out_grid, + index_t batch_size, + index_t q_sequence_length, + index_t kv_sequence_length, + index_t head_count, + index_t head_size, + float alpha) + { + return CrossAttnArg{p_q_grid, + p_kv_grid, + p_out_grid, + batch_size, + q_sequence_length, + kv_sequence_length, + head_count, + head_size, + alpha}; + } + + // Argument + struct Argument : public BaseArgument + { + Argument( + const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + const index_t M01, + const index_t N01, + AElementwiseOperation a_element_op, + 
B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc{ + DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc{ + DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_m_n_{ + Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + a_grid_desc_g_m_k_{ + Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc_g_l_k_{ + Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc_g_n_l_{ + Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_g_m_n_{ + Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, + a_element_op_{a_element_op}, + b0_element_op_{b0_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + c0_matrix_mask_{b0_grid_desc_g_l_k_.GetLength(I1)}, + raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1], + b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]}, + a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]}, + b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1], + b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]}, + c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1], + c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]}, + batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, + compute_ptr_offset_of_batch_{ + a_grid_desc_g_m_k_, b0_grid_desc_g_l_k_, b1_grid_desc_g_n_l_, c_grid_desc_g_m_n_} + { + // TODO ANT: implement bias addition + ignore = p_acc0_biases; + ignore = p_acc1_biases; + ignore = acc0_biases_gs_ms_ls_lengths; + ignore = acc0_biases_gs_ms_ls_strides; + ignore = acc1_biases_gs_ms_ns_lengths; + ignore = acc1_biases_gs_ms_ns_strides; + + if(GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n_, block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Tensor Descriptors + AGridDesc a_grid_desc; + B0GridDesc b0_grid_desc; + B1GridDesc b1_grid_desc; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + + typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // Block to Tile mapping + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_; + + // ElementwiseOp + AElementwiseOperation a_element_op_; + B0ElementwiseOperation b0_element_op_; + 
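// ---------------------------------------------------------------------------
// Illustrative aside (not part of the original patch): the raw lengths and
// lowest-dimension strides recorded by this Argument exist so that
// IsSupportedArgument (above) can verify the configured vector width: the
// fastest-varying extent must divide evenly by ScalarPerVector and the
// dimension being vectorized must be contiguous. Below is a simplified,
// stand-alone model of that per-operand predicate (the real check combines
// the per-operand results slightly differently); the helper name is
// illustrative only.
#include <cassert>
#include <cstdint>

static bool vector_access_ok(std::int64_t lowest_extent,
                             std::int64_t lowest_stride,
                             std::int64_t scalar_per_vector)
{
    return lowest_extent % scalar_per_vector == 0 && lowest_stride == 1;
}

int main()
{
    assert(vector_access_ok(/*K=*/128, /*stride=*/1, /*vec=*/8));   // 8-wide loads are fine
    assert(!vector_access_ok(/*K=*/100, /*stride=*/1, /*vec=*/8));  // 100 % 8 != 0
    assert(!vector_access_ok(/*K=*/128, /*stride=*/4, /*vec=*/8));  // non-unit stride, no vector load
    return 0;
}
// ---------------------------------------------------------------------------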
AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // Strides for the last M/N/K dimensions of A/B0/B1/C + // for sanity check of vector load/store + std::array raw_lengths_mz_lz_kz_nz_; + std::array a_mz_kz_strides_; + std::array b0_lz_kz_strides_; + std::array b1_nz_lz_strides_; + std::array c_mz_nz_strides_; + + index_t batch_count_; + // Batch Offset + ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_; + }; + + // Invoker + struct SelfAttnInvoker : public BaseInvoker + { + using Argument = DeviceOp::SelfAttnArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.sequence_length_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.head_size_, NPerBlock); + + const index_t grid_size = arg.batch_size_ * arg.head_count_ * M0 * N0; + const auto K = arg.head_size_; + + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_wmma_self_attention_forward; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_qkv_grid_, + arg.p_out_grid_, + arg.batch_size_, + arg.sequence_length_, + arg.head_count_, + arg.head_size_, + arg.alpha_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static auto MakeSelfAttnInvoker() { return SelfAttnInvoker{}; } + + // Invoker + struct CrossAttnInvoker : public BaseInvoker + { + using Argument = DeviceOp::CrossAttnArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.q_sequence_length_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.head_size_, NPerBlock); + + const index_t grid_size = arg.batch_size_ * arg.head_count_ * M0 * N0; + const auto K = arg.head_size_; + + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_wmma_cross_attention_forward; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_q_grid_, + arg.p_kv_grid_, + arg.p_out_grid_, + arg.batch_size_, + arg.q_sequence_length_, + arg.kv_sequence_length_, + arg.head_count_, + arg.head_size_, + arg.alpha_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static auto MakeCrossAttnInvoker() { return CrossAttnInvoker{}; } + + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::RawArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock); + + const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0; + const auto K = arg.K_; + // printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K)); + auto 
launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = + kernel_batched_gemm_softmax_gemm_wmma_cshuffle; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b0_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.M_, + arg.N_, + arg.K_, + arg.O_, + arg.G0_, + arg.G1_, + arg.alpha_, + arg.input_permute_, + arg.output_permute_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } +#if 0 + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_gfx11_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.b1_grid_desc, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded + + if(!(c_g == arg.batch_count_)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0]; + const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1]; + const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2]; + const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0]; + const auto c_stride_lowest = arg.c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b0, + p_b1, + p_c, + p_acc0_biases, + p_acc1_biases, + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ls_ks_lengths, + b0_gs_ls_ks_strides, + b1_gs_ns_ls_lengths, + b1_gs_ns_ls_strides, + c_gs_ms_ns_lengths, + c_gs_ms_ns_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op}; + } +#endif + + // polymorphic + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b0_gs_ls_ks_lengths, + const std::vector& b0_gs_ls_ks_strides, + const std::vector& b1_gs_ns_ls_lengths, + const std::vector& b1_gs_ns_ls_strides, + const std::vector& c_gs_ms_ns_lengths, + const std::vector& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + std::array a_lengths; + std::array a_strides; + std::array b0_lengths; + std::array b0_strides; + std::array b1_lengths; + std::array b1_strides; + std::array c_lengths; + std::array c_strides; + std::transform(a_gs_ms_ks_lengths.begin(), + a_gs_ms_ks_lengths.end(), + a_lengths.begin(), + [](index_t i) { return i; }); + std::transform(a_gs_ms_ks_strides.begin(), + a_gs_ms_ks_strides.end(), + a_strides.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_lengths.begin(), + b0_gs_ls_ks_lengths.end(), + b0_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_strides.begin(), + 
b0_gs_ls_ks_strides.end(), + b0_strides.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_lengths.begin(), + b1_gs_ns_ls_lengths.end(), + b1_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_strides.begin(), + b1_gs_ns_ls_strides.end(), + b1_strides.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_lengths.begin(), + c_gs_ms_ns_lengths.end(), + c_lengths.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_strides.begin(), + c_gs_ms_ns_strides.end(), + c_strides.begin(), + [](index_t i) { return i; }); + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + static_cast(p_b1), + static_cast(p_c), + p_acc0_biases, + p_acc1_biases, + a_lengths, + a_strides, + b0_lengths, + b0_strides, + b1_lengths, + b1_strides, + c_lengths, + c_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << LPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << LTilePerBlock << ", " + << L1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(B0Spec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << ", " + << getMaskingSpecializationString(MaskingSpec) + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "B0EnableLds: " + << B0EnableLds << ", " + << "B1EnableLds: " + << B1EnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp new file mode 100755 index 000000000000..bae5c6019d46 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -0,0 +1,955 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
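// ---------------------------------------------------------------------------
// Illustrative aside (not part of the original patch): every Invoker in the
// WMMA file above turns the runtime answer of CalculateHasMainKBlockLoop(K)
// into a template parameter by calling a generic lambda with
// integral_constant<bool, ...>, so the kernel is instantiated twice and the
// K-loop branch is resolved at compile time. A minimal stand-alone version of
// that dispatch pattern follows; run_kernel_stub and dispatch_has_main_k_loop
// are placeholder names, not CK API.
#include <cstdio>
#include <type_traits>

template <bool HasMainKBlockLoop>
void run_kernel_stub(int k)
{
    if constexpr (HasMainKBlockLoop)
        std::printf("K=%d: compiled with main K loop\n", k);
    else
        std::printf("K=%d: compiled without main K loop\n", k);
}

template <typename F>
void dispatch_has_main_k_loop(bool has_main_k_loop, F&& launch)
{
    // Same shape as the if/else at the end of each Invoker::Run above.
    if (has_main_k_loop)
        launch(std::integral_constant<bool, true>{});
    else
        launch(std::integral_constant<bool, false>{});
}

int main()
{
    for (int k : {32, 256})
    {
        const bool has_loop = k > 64; // stand-in for CalculateHasMainKBlockLoop(K)
        dispatch_has_main_k_loop(has_loop, [&](auto has_loop_c) {
            run_kernel_stub<decltype(has_loop_c)::value>(k);
        });
    }
    return 0;
}
// ---------------------------------------------------------------------------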
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatAB* __restrict__ p_b1_grid, + FloatC* __restrict__ p_c_grid, + D0sPointer p_d0s_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const C0DEElementwiseOperation c0de_element_op, + const B1ElementwiseOperation b1_element_op, + const C1DEElementwiseOperation c1de_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, + const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c1_grid_desc_mblock_mperblock_nblock_nperblock, + const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + const Block2CTileMap block_2_ctile_map, + const index_t batch_count, + const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, + const C0MatrixMask c0_matrix_mask) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + static_for<0, p_d0s_grid.Size(), 1>{}([&](auto In) { + const long_index_t d0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetD0BasePtr(g_idx, In))); + p_d0s_grid(In) = p_d0s_grid(In) + d0_batch_offset; + }); + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_d0s_grid, + p_shared, + a_element_op, + b_element_op, + c0de_element_op, + b1_element_op, + c1de_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c1_grid_desc_mblock_mperblock_nblock_nperblock, + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + 
block_2_ctile_map, + c0_matrix_mask); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = p_d0s_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = c0de_element_op; + ignore = b1_element_op; + ignore = c1de_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = b1_grid_desc_bk0_n_bk1; + ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5; + ignore = block_2_ctile_map; + ignore = batch_count; + ignore = compute_base_ptr_of_batch; + ignore = c0_matrix_mask; +#endif // end of if (defined(__gfx9__)) +} + +// Computes C = A * B0 * B1 +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle + : public DeviceBatchedGemmSoftmaxGemmPermute +{ + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t NumD0Tensor = D0sDataType::Size(); + static constexpr index_t NumD1Tensor = D1sDataType::Size(); + + // TODO ANT: implement bias combination + static_assert(NumD1Tensor == 0, "Gemm1 Bias addition is unimplemented"); + +#if 0 + // TODO ANT: use alias + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimN; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimO; + static constexpr index_t NumDimGemm1K = NumDimN; +#endif + + using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm< + Sequence, + Sequence, + GemmSpec, + ASpec, + BSpec, + B1Spec, + CSpec>; + + static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); + } + + static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector& b_gs_ns_ks_lengths_vec, + const std::vector& b_gs_ns_ks_strides_vec) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec), + Number{}); + } + + static auto + MakeB1GridDescriptor_BK0_N_BK1(const std::vector& b1_gs_gemm1ns_gemm1ks_lengths_vec, + const std::vector& b1_gs_gemm1ns_gemm1ks_strides_vec) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec, + b1_gs_gemm1ns_gemm1ks_strides_vec), + Number{}); + } + + static auto MakeD0sGridDescriptor_M_N( + const std::array, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths, + const std::array, NumD0Tensor>& acc0_biases_gs_ms_ns_strides) + { + return generate_tuple( + [&](auto i) { + return Transform::MakeCGridDescriptor_M_N(acc0_biases_gs_ms_ns_lengths[i], + acc0_biases_gs_ms_ns_strides[i]); + }, + Number{}); + } + + static auto MakeD0sGridDescriptor_G_M_N( + const std::array, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths, + const std::array, NumD0Tensor>& acc0_biases_gs_ms_ns_strides) + { + return generate_tuple( + [&](auto i) { + return Transform::MakeCGridDescriptor_G_M_N(acc0_biases_gs_ms_ns_lengths[i], + acc0_biases_gs_ms_ns_strides[i]); + }, 
+ Number{}); + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); + using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})); + using C1GridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using C1GridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + using D0sGridDesc_M_N = decltype(MakeD0sGridDescriptor_M_N({}, {})); + using D0sGridDesc_G_M_N = decltype(MakeD0sGridDescriptor_G_M_N({}, {})); + + constexpr static auto make_MaskOutPredicate() + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + return MaskDisabledPredicate{}; + } + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + return MaskOutUpperTrianglePredicate{}; + } + } + using C0MatrixMask = C0MatrixMask_impl; + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const BGridDesc_G_N_K& b_grid_desc_g_n_k, + const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, + const C1GridDesc_G_M_N& c1_grid_desc_g_m_n, + const D0sGridDesc_G_M_N& d0s_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b_grid_desc_g_n_k_(b_grid_desc_g_n_k), + b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), + c1_grid_desc_g_m_n_(c1_grid_desc_g_m_n), + d0s_grid_desc_g_m_n_(d0s_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const + { + return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return c1_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + template + __host__ __device__ constexpr long_index_t GetD0BasePtr(index_t g_idx, + Number d0_idx) const + { + return d0s_grid_desc_g_m_n_[d0_idx].CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + private: + AGridDesc_G_M_K a_grid_desc_g_m_k_; + BGridDesc_G_N_K b_grid_desc_g_n_k_; + B1GridDesc_G_N_K b1_grid_desc_g_n_k_; + C1GridDesc_G_M_N c1_grid_desc_g_m_n_; + D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_; + }; + + // GridwiseGemm + using GridwiseGemm = GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + D0sDataType, + AElementwiseOperation, + BElementwiseOperation, + C0DEElementwiseOperation, + B1ElementwiseOperation, + C1DEElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + B1GridDesc_BK0_N_BK1, + C1GridDesc_M_N, + D0sGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1, + BK1, + B1K1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + 
ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + true, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + true, + BBlockLdsExtraN, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + false, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, + D0sTransferSrcScalarPerVector>; + + // Argument + // FIXME: constness + struct Argument : public BaseArgument + { + Argument( + const ADataType* p_a_grid, + const BDataType* p_b_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::vector& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths + const std::vector& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides + const std::vector& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths + const std::vector& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides + const std::array, NumD0Tensor>& acc0_biases_gs_ms_ns_lengths, + const std::array, NumD0Tensor>& acc0_biases_gs_ms_ns_strides, + const std::array, NumD1Tensor>& + acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths + const std::array, NumD1Tensor>& + acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + C0DEElementwiseOperation c0de_element_op, + B1ElementwiseOperation b1_element_op, + C1DEElementwiseOperation c1de_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + p_d0s_grid_{}, + a_grid_desc_ak0_m_ak1_{ + DeviceOp::MakeAGridDescriptor_AK0_M_AK1(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b_grid_desc_bk0_n_bk1_{ + DeviceOp::MakeBGridDescriptor_BK0_N_BK1(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, + b1_grid_desc_bk0_n_bk1_{DeviceOp::MakeB1GridDescriptor_BK0_N_BK1( + b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, + c1_grid_desc_m_n_{Transform::MakeCGridDescriptor_M_N(c_gs_ms_gemm1ns_lengths, + c_gs_ms_gemm1ns_strides)}, + a_grid_desc_g_m_k_{ + Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b_grid_desc_g_n_k_{ + Transform::MakeB0GridDescriptor_G_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, + b1_grid_desc_g_n_k_{Transform::MakeB1GridDescriptor_G_N_K( + b1_gs_gemm1ns_gemm1ks_lengths, b1_gs_gemm1ns_gemm1ks_strides)}, + c1_grid_desc_g_m_n_{Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_gemm1ns_lengths, + c_gs_ms_gemm1ns_strides)}, + d0s_grid_desc_g_m_n_{DeviceOp::MakeD0sGridDescriptor_G_M_N( + acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)}, + c1_grid_desc_mblock_mperblock_nblock_nperblock_{}, + d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_{}, + 
block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c1_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c0de_element_op_{c0de_element_op}, + b1_element_op_{b1_element_op}, + c1de_element_op_{c1de_element_op}, + c0_matrix_mask_{b_grid_desc_g_n_k_.GetLength(I1)}, + raw_lengths_mz_nz_kz_gemm1nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + b_gs_ns_ks_lengths[NumDimG + NumDimN - 1], + b_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1], + b1_gs_gemm1ns_gemm1ks_lengths[NumDimG + NumDimO - 1]}, + a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + b_nz_kz_strides_{b_gs_ns_ks_strides[NumDimG + NumDimN - 1], + b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]}, + b1_nz_kz_strides_{b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO - 1], + b1_gs_gemm1ns_gemm1ks_strides[NumDimG + NumDimO + NumDimN - 1]}, + c_mz_gemm1nz_strides_{c_gs_ms_gemm1ns_strides[NumDimG + NumDimM - 1], + c_gs_ms_gemm1ns_strides[NumDimG + NumDimM + NumDimO - 1]}, + batch_count_{c1_grid_desc_g_m_n_.GetLength(I0)}, + compute_base_ptr_of_batch_{a_grid_desc_g_m_k_, + b_grid_desc_g_n_k_, + b1_grid_desc_g_n_k_, + c1_grid_desc_g_m_n_, + d0s_grid_desc_g_m_n_} + { + // TODO ANT: implement bias addition + ignore = p_acc1_biases; + ignore = acc1_biases_gs_ms_gemm1ns_lengths; + ignore = acc1_biases_gs_ms_gemm1ns_strides; + + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + using D0DataType = remove_cvref_t>; + // D0 pointer + p_d0s_grid_(i) = static_cast(p_acc0_biases[i]); + // for check + d0s_nl_ns_lengths_strides_[i].push_back( + acc0_biases_gs_ms_ns_lengths[i][NumDimG + NumDimM]); + d0s_nl_ns_lengths_strides_[i].push_back( + acc0_biases_gs_ms_ns_strides[i][NumDimG + NumDimM]); + }); + + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + b1_grid_desc_bk0_n_bk1_, + c1_grid_desc_m_n_, + block_2_ctile_map_)) + { + c1_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c1_grid_desc_m_n_); + + D0sGridDesc_M_N d0s_grid_desc_m_n{DeviceOp::MakeD0sGridDescriptor_M_N( + acc0_biases_gs_ms_ns_lengths, acc0_biases_gs_ms_ns_strides)}; + d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_ = + GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5( + d0s_grid_desc_m_n); + } + } + + void Print() const + { + std::cout << "a_grid_desc_g_m_k_: " << a_grid_desc_g_m_k_.GetLength(I0) << ", " + << a_grid_desc_g_m_k_.GetLength(I1) << ", " + << a_grid_desc_g_m_k_.GetLength(I2) << '\n'; + std::cout << "b_grid_desc_g_n_k_: " << b_grid_desc_g_n_k_.GetLength(I0) << ", " + << b_grid_desc_g_n_k_.GetLength(I1) << ", " + << b_grid_desc_g_n_k_.GetLength(I2) << '\n'; + std::cout << "b1_grid_desc_g_n_k_: " << b1_grid_desc_g_n_k_.GetLength(I0) << ", " + << b1_grid_desc_g_n_k_.GetLength(I1) << ", " + << b1_grid_desc_g_n_k_.GetLength(I2) << '\n'; + std::cout << "c1_grid_desc_g_m_n_: " << c1_grid_desc_g_m_n_.GetLength(I0) << ", " + << c1_grid_desc_g_m_n_.GetLength(I1) << ", " + << c1_grid_desc_g_m_n_.GetLength(I2) << '\n'; + } + + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + typename GridwiseGemm::D0sGridPointer p_d0s_grid_; + + // tensor descriptor + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; + C1GridDesc_M_N c1_grid_desc_m_n_; + AGridDesc_G_M_K a_grid_desc_g_m_k_; + BGridDesc_G_N_K 
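// [Illustrative note, not part of the original source] Example of the gs_ms_ks indexing used
// above, assuming hypothetical NumDimG = 2 and NumDimM = NumDimN = NumDimK = NumDimO = 1:
// a_gs_ms_ks_lengths = {G0, G1, M, K}, so the lowest ("z") M length is
// a_gs_ms_ks_lengths[NumDimG + NumDimM - 1] == a_gs_ms_ks_lengths[2] == M, and the lowest K
// length of A is a_gs_ms_ks_lengths[NumDimG + NumDimM + NumDimK - 1] == a_gs_ms_ks_lengths[3] == K.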
b_grid_desc_g_n_k_; + B1GridDesc_G_N_K b1_grid_desc_g_n_k_; + C1GridDesc_G_M_N c1_grid_desc_g_m_n_; + D0sGridDesc_G_M_N d0s_grid_desc_g_m_n_; + + typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c1_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 + d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_; + + // block-to-c-tile map + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + C0DEElementwiseOperation c0de_element_op_; + B1ElementwiseOperation b1_element_op_; + C1DEElementwiseOperation c1de_element_op_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // For robust IsSupportedArgument() check + std::vector raw_lengths_mz_nz_kz_gemm1nz_; + std::vector a_mz_kz_strides_; + std::vector b_nz_kz_strides_; + std::vector b1_nz_kz_strides_; + std::vector c_mz_gemm1nz_strides_; + std::array, NumD0Tensor> d0s_nl_ns_lengths_strides_; + + index_t batch_count_; + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!DeviceOp::IsSupportedArgument(arg)) + { + throw std::runtime_error("wrong! unsupported argument"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c1_grid_desc_m_n_) * arg.batch_count_; + + // Gemm0_K + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + typename GridwiseGemm::D0sGridPointer, + AElementwiseOperation, + BElementwiseOperation, + C0DEElementwiseOperation, + B1ElementwiseOperation, + C1DEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::B1GridDesc_BK0_N_BK1, + typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5, + typename GridwiseGemm::DefaultBlock2CTileMap, + ComputeBasePtrOfStridedBatch, + C0MatrixMask, + has_main_k_block_loop_>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.p_d0s_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c0de_element_op_, + arg.b1_element_op_, + arg.c1de_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, + arg.block_2_ctile_map_, + arg.batch_count_, + arg.compute_base_ptr_of_batch_, + arg.c0_matrix_mask_); + }; + + // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need + // to concern Gemm0's loop + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), 
stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + arg.Print(); + } + + if(!ck::is_xdl_supported()) + { + return false; + } + + // TODO ANT: Check if tensor specialization & strides mismatch + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = arg.c1_grid_desc_g_m_n_.GetLength(I0); // unpadded + const index_t c_m = arg.c1_grid_desc_m_n_.GetLength(I0); + const index_t c_gemm1n = arg.c1_grid_desc_m_n_.GetLength(I1); + const index_t a_m = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); + const index_t b1_gemm1n = arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1); + + if(!(c_g == arg.batch_count_ && c_m == a_m && c_gemm1n == b1_gemm1n)) + { + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[0]; + const auto NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[1]; + const auto KzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[2]; + const auto Gemm1NzRaw = arg.raw_lengths_mz_nz_kz_gemm1nz_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw; + const auto c_extent_lowest = Gemm1NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; + const auto b_stride_lowest = + BBlockTransferSrcVectorDim == 2 ? arg.b_nz_kz_strides_[1] : arg.b_nz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
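// [Illustrative note, not part of the original source] Example of the raw-length requirement
// checked above, with hypothetical values: for Gemm1NzRaw = 100 and
// CShuffleBlockTransferScalarPerVector_NPerBlock = 8, 100 % 8 != 0 and the argument is
// rejected even though the padded length would be divisible, because a vectorized access
// must not straddle the raw-data boundary.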
arg.b1_nz_kz_strides_[1] : arg.b1_nz_kz_strides_[0]; + const auto c_stride_lowest = + arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be contiguous + + if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + return false; + } + for(int i = 0; i < NumD0Tensor; i++) + { + if(arg.d0s_nl_ns_lengths_strides_[i][1] == 1 && + arg.d0s_nl_ns_lengths_strides_[i][0] % D0sTransferSrcScalarPerVector != 0) + { + return false; + } + if(arg.d0s_nl_ns_lengths_strides_[i][1] != 1 && D0sTransferSrcScalarPerVector != 1) + { + return false; + } + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c1_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const ADataType* p_a, + const BDataType* p_b, + const B1DataType* p_b1, + CDataType* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::vector& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths + const std::vector& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides + const std::vector& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths + const std::vector& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides + const std::array, NumD0Tensor> acc0_biases_gs_ms_ns_lengths, + const std::array, NumD0Tensor> acc0_biases_gs_ms_ns_strides, + const std::array, NumD1Tensor> + acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths + const std::array, NumD1Tensor> + acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + C0DEElementwiseOperation c0de_element_op, + B1ElementwiseOperation b1_element_op, + C1DEElementwiseOperation c1de_element_op) + { + return Argument{p_a, + p_b, + p_b1, + p_c, + p_acc0_biases, + p_acc1_biases, + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths + b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides + c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths + c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides + acc0_biases_gs_ms_ns_lengths, + acc0_biases_gs_ms_ns_strides, + acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths + acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides + a_element_op, + b_element_op, + c0de_element_op, + b1_element_op, + c1de_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + // FIXME: constness + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::vector& b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths + const std::vector& b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides + const std::vector& c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths + const std::vector& c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides + const std::array, NumD0Tensor> 
acc0_biases_gs_ms_ns_lengths, + const std::array, NumD0Tensor> acc0_biases_gs_ms_ns_strides, + const std::array, NumD1Tensor> + acc1_biases_gs_ms_gemm1ns_lengths, // acc1_biases_gs_ms_os_lengths + const std::array, NumD1Tensor> + acc1_biases_gs_ms_gemm1ns_strides, // acc1_biases_gs_ms_os_strides + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + C0DEElementwiseOperation c0de_element_op, + B1ElementwiseOperation b1_element_op, + C1DEElementwiseOperation c1de_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_b1), + static_cast(p_c), + p_acc0_biases, // cast in struct Argument + p_acc1_biases, // cast in struct Argument + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + b1_gs_gemm1ns_gemm1ks_lengths, // b1_gs_os_ns_lengths + b1_gs_gemm1ns_gemm1ks_strides, // b1_gs_os_ns_strides + c_gs_ms_gemm1ns_lengths, // c_gs_ms_os_lengths + c_gs_ms_gemm1ns_strides, // c_gs_ms_os_strides + acc0_biases_gs_ms_ns_lengths, + acc0_biases_gs_ms_ns_strides, + acc1_biases_gs_ms_gemm1ns_lengths, + acc1_biases_gs_ms_gemm1ns_strides, + a_element_op, + b_element_op, + c0de_element_op, + b1_element_op, + c1de_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << Gemm1NPerBlock << ", " + << Gemm1KPerBlock << ", " + << B1K1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(BSpec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << ", " + << getMaskingSpecializationString(MaskingSpec) << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp new file mode 100755 index 000000000000..e846b0630b7f --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp @@ -0,0 +1,1112 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#ifndef __HIPCC_RTC__ +#include +#include +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#endif + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/masking_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatAB* __restrict__ p_b1_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const B1ElementwiseOperation b1_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const index_t batch_count, + const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, + const C0MatrixMask c0_matrix_mask) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map, + c0_matrix_mask); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = acc_element_op; + ignore = b1_element_op; + ignore = c_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = b1_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_ctile_map; + ignore = batch_count; + ignore = 
compute_base_ptr_of_batch; + ignore = c0_matrix_mask; +#endif // end of if (defined(__gfx9__)) +} + +// Computes C = A * B0 * B1 +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) + +// When using NPadding as GemmSpecialization, AccElementwiseOperation should be set to +// ScaleAndResetNaNToMinusInfinity. +// if !isNan(AccElement) +// AccElement *= scale +// else +// AccElement = -INFINITY +// Otherwise, result may be wrong. + +template +struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle + : public DeviceBatchedGemmSoftmaxGemm +{ + using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto matrix_padder = + GemmGemmPadder{ + MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // Args: Gemm1KRaw, Gemm1NRaw, StrideB1 + static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b1_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto b1_grid_desc_n_k = matrix_padder.PadB1Descriptor_N_K(b1_grid_desc_nraw_kraw); + + const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, 
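// [Illustrative note, not part of the original source] Rationale for the NPadding remark above:
// resetting invalid Acc0 entries to -INFINITY makes them contribute exp(-inf) == 0 to the
// row-wise softmax normalization, so padded N positions cannot perturb the result.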
index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); + } + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideB1_(BatchStrideB1), + BatchStrideC_(BatchStrideC) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB1_); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideB1_; + index_t BatchStrideC_; + }; + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + using C0MatrixMask = conditional_t, + C0MatrixMask_impl>; + + // GridwiseGemm + using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + B1GridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1, + BK1, + B1K1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + true, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + true, + BBlockLdsExtraN, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + false, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + matrix_padder.PadN, + MaskOutUpperTriangle>; + +#ifndef __HIPCC_RTC__ + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const 
BDataType* p_b_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, // = ORaw + index_t Batch, + index_t StrideA, + index_t StrideB, + index_t StrideB1, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + b1_grid_desc_bk0_n_bk1_{ + DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, Gemm1NRaw, StrideC)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + batch_count_(Batch), + compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC}, + c0_matrix_mask_{NRaw}, + raw_lengths_m_n_k_o_{MRaw, NRaw, KRaw, Gemm1NRaw} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + b1_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + index_t batch_count_; + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // For robust IsSupportedArgument() check + std::vector raw_lengths_m_n_k_o_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_; + + // Gemm0_K + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::B1GridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + ComputeBasePtrOfStridedBatch, + C0MatrixMask, + has_main_k_block_loop_>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.b1_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_, + arg.batch_count_, + arg.compute_base_ptr_of_batch_, + arg.c0_matrix_mask_); + }; + + // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need + // to concern Gemm0's loop + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; +#endif + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static constexpr bool + IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_, index_t Gemm1NRaw_) + { + // check vector load/store + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // check vector load of A + if constexpr(is_same_v) + { + if(KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of B + if constexpr(is_same_v) + { + if(NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of B1 + if constexpr(is_same_v) + { + if(Gemm1NRaw_ % B1BlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(NRaw_ % B1BlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of C + if constexpr(is_same_v) + { + if(Gemm1NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else if constexpr(is_same_v) + { + if(MRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + + return true; + } + +#ifndef 
__HIPCC_RTC__ + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + const auto MRaw = arg.raw_lengths_m_n_k_o_[0]; + const auto NRaw = arg.raw_lengths_m_n_k_o_[1]; + const auto KRaw = arg.raw_lengths_m_n_k_o_[2]; + const auto Gemm1NRaw = arg.raw_lengths_m_n_k_o_[3]; + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.b1_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_) and + IsSupported(MRaw, NRaw, KRaw, Gemm1NRaw); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + const B1DataType* p_b1, + CDataType* p_c, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, + index_t Batch, + index_t StrideA, + index_t StrideB, + index_t StrideB1, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, p_b, p_b1, p_c, MRaw, + NRaw, KRaw, Gemm1NRaw, Batch, StrideA, + StrideB, StrideB1, StrideC, BatchStrideA, BatchStrideB, + BatchStrideB1, BatchStrideC, a_element_op, b_element_op, acc_element_op, + b1_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_b1, + void* p_c, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t Gemm1NRaw, + index_t Batch, + index_t StrideA, + index_t StrideB, + index_t StrideB1, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideB1, + index_t BatchStrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_b1), + static_cast(p_c), + MRaw, + NRaw, + KRaw, + Gemm1NRaw, + Batch, + StrideA, + StrideB, + StrideB1, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideB1, + BatchStrideC, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << Gemm1NPerBlock << ", " + << Gemm1KPerBlock << ", " + << B1K1 << ", " + << getGemmSpecializationString(GemmSpec) << ">"; + // clang-format on + + return str.str(); + } +#endif + + template + struct Descriptor + { + template + static constexpr auto MakeAGridDescriptor_AK0_M_AK1(const AGridDescriptor& a_grid_desc) + { + const auto a_grid_desc_m_k = DeviceOp::matrix_padder.PadADescriptor_M_K(a_grid_desc); + + const auto M = 
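// [Illustrative usage sketch, not part of the original source; the op instance, tensor
//  pointers and elementwise-op names below are hypothetical placeholders.]
// Typical host-side flow for the device op defined above:
#if 0
    using Op = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle</* instance template arguments */>;
    Op op;
    auto argument = op.MakeArgument(p_q, p_k, p_v, p_out,
                                    M, N, K, O, Batch,
                                    StrideA, StrideB, StrideB1, StrideC,
                                    BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC,
                                    AElementwiseOp{}, BElementwiseOp{},
                                    AccElementwiseOp{scale}, B1ElementwiseOp{}, CElementwiseOp{});
    if(op.IsSupportedArgument(argument))
    {
        auto invoker = op.MakeInvoker();
        float ave_ms = invoker.Run(argument); // returns average kernel time when timing is enabled
    }
#endif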
a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static constexpr auto MakeBGridDescriptor_BK0_N_BK1(const BGridDescriptor& b_grid_desc) + { + const auto b_grid_desc_n_k = DeviceOp::matrix_padder.PadBDescriptor_N_K(b_grid_desc); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static constexpr auto MakeB1GridDescriptor_BK0_N_BK1(const B1GridDescriptor& b1_grid_desc) + { + const auto b1_grid_desc_n_k = DeviceOp::matrix_padder.PadB1Descriptor_N_K(b1_grid_desc); + + const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static constexpr auto MakeCGridDescriptor_M_N(const CGridDescriptor& c_grid_desc) + { + return DeviceOp::matrix_padder.PadCDescriptor_M_N(c_grid_desc); + } + + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using B1GridDesc_BK0_N_BK1 = + remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + B1GridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1, + BK1, + B1K1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + true, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + true, + BBlockLdsExtraN, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + false, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + matrix_padder.PadN, + MaskOutUpperTriangle>; + + 
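// [Illustrative note, not part of the original source] Worked example of the K split performed
// by the descriptor helpers above, with hypothetical sizes: for KRaw = 64 and AK1 = 8, the
// padded K is unmerged into AK0 = K / AK1 = 8, so an (M, K) = (128, 64) A tile is viewed as an
// (AK0, M, AK1) = (8, 128, 8) descriptor, with AK1 being the vector load size used for A.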
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1; + B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1; + CGridDesc_M_N c_grid_desc_m_n; + C0MatrixMask c0_matrix_mask; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_descriptor_mblock_mperblock_nblock_nperblock; + + // element-wise op + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + B1ElementwiseOperation b1_element_op; + CElementwiseOperation c_element_op; + + bool has_main_k_block_loop = true; + bool is_valid = false; + + constexpr Descriptor(ADesc a, + BDesc b, + B1Desc b1, + CDesc c, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + B1ElementwiseOperation b1_element_op_, + CElementwiseOperation c_element_op_) + : a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(a)}, + b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(b)}, + b1_grid_desc_bk0_n_bk1{MakeB1GridDescriptor_BK0_N_BK1(b1)}, + c_grid_desc_m_n{MakeCGridDescriptor_M_N(c)}, + block_2_ctile_map{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n)}, + c_grid_descriptor_mblock_mperblock_nblock_nperblock{ + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n)}, + has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop( + a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))}, + c0_matrix_mask{c.GetLength(I1)}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + b1_element_op{b1_element_op_}, + c_element_op{c_element_op_}, + is_valid{GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_m_n, + block_2_ctile_map) and + IsSupported(a_grid_desc_ak0_m_ak1.GetLength(I1), + b_grid_desc_bk0_n_bk1.GetLength(I1), + a_grid_desc_ak0_m_ak1.GetLength(I0) * + a_grid_desc_ak0_m_ak1.GetLength(I2), + b1_grid_desc_bk0_n_bk1.GetLength(I1))} + { + } + + constexpr bool IsValid() const { return is_valid; } + }; + + template + static constexpr auto + make_descriptor(ADesc a, + BDesc b, + B1Desc b1, + CDesc c, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + B1ElementwiseOperation b1_element_op = B1ElementwiseOperation{}, + CElementwiseOperation c_element_op = CElementwiseOperation{}) + { + return Descriptor( + a, b, b1, c, a_element_op, b_element_op, b1_element_op, c_element_op); + } + + template + __device__ static void Run(const Desc& desc, + const float scale, + const ADataType* __restrict__ p_a_grid, + const ADataType* __restrict__ p_b_grid, + const ADataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid) + { +#ifndef __HIPCC_RTC__ + assert(desc.is_valid); +#endif + __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()]; + AccElementwiseOperation acc_element_op{scale}; + + if(desc.has_main_k_block_loop) + { + Desc::GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_b1_grid, + p_c_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + acc_element_op, + desc.b1_element_op, + desc.c_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.b1_grid_desc_bk0_n_bk1, + desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock, + desc.block_2_ctile_map, + desc.c0_matrix_mask); + } + else + { + Desc::GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_b1_grid, + p_c_grid, + p_shared_block, + desc.a_element_op, + 
desc.b_element_op, + acc_element_op, + desc.b1_element_op, + desc.c_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.b1_grid_desc_bk0_n_bk1, + desc.c_grid_descriptor_mblock_mperblock_nblock_nperblock, + desc.block_2_ctile_map, + desc.c0_matrix_mask); + } + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp new file mode 100755 index 000000000000..abd6574d8cc4 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3.hpp @@ -0,0 +1,758 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_wmma_cshuffle_v3( + typename GridwiseGemm::Argument + karg, // This works for now but it actually receives a + // DeviceBatchedGemm_Wmma_CShuffleV3::Argument + // argument through implicit conversion to base class! + const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + // The normal approach to batching would be to increase the grid size by just stretching out + // the grid Z dimension (which is the outermost dimension), but this depends on lower level + // functions not directly using the Z dimension for other calculations. As it turns out, k + // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now + // we will use the grid Y dimension for batching. This may be a bit fragile. 
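    // [Illustrative note, not part of the original source] In short: the host-side Invoker
    // multiplies the grid Y dimension by the batch count, so (assuming the per-batch grid
    // keeps gdy == 1) blockIdx.y enumerates batches; each batch's tensors are then reached
    // by adding g_idx * BatchStride through compute_ptr_offset_of_batch, e.g.
    //   p_a_grid + compute_ptr_offset_of_batch.GetAPtrOffset(g_idx).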
+ const index_t g_idx = amd_wave_read_first_lane(blockIdx.y); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset + a_batch_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset + b_batch_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, + p_shared, + karg); +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; + ignore = compute_ptr_offset_of_batch; +#endif +} + +/// @brief \"Universal\" Batched GEMM operation without SplitK support. +/// +/// @par Overview +/// This GEMM operation implements the following mathematical equation: +/// C{G,M,N} = C_op(A_op(A{G,M,K}) * B_op(B{G,K,N})) +/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are +/// elementwise operations applied to the A, B, and C tensors, respectively. +/// The \"universal\" gemm comes with multiple pipelines optimized for different usage +/// scenarios. That's why it's called \"universal\". It's universal through its design +/// and versatilty. +/// +/// @note This Kernel implementation currently does not support the SplitK algorithm. +/// +/// @tparam ALayout A tensor data layout. +/// @tparam BLayout B tensor data layout. +/// @tparam CLayout C tensor data layout. +/// @tparam ADataType A tensor data type. +/// @tparam BDataType B tensor data type. +/// @tparam CDataType C tensor data type. +/// @tparam AccDataType The accumulation data type related to the hardware +/// matrix-multiplication instruction. +/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into +/// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor +/// (after GEMM). +/// @tparam GemmSpec Determines used "padding" version. +/// @tparam BlockSize The number of threads within workgroup. +/// @tparam MPerBlock The input/output data tile size in the M dimension. +/// @tparam NPerBlock The input/output data tile size in the N dimension. +/// @tparam KPerBlock The input data tile size in the K dimension. +/// @tparam AK1 The vector load size from global memory for A tensor. +/// @tparam BK1 The vector load size from global memory for B tensor. +/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront. +/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront. +/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question, "How many threads can be +/// arranged on each input data axis?" 
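/// Example (hypothetical numbers, for illustration only): with BlockSize = 256 and an
/// (AK0, M, AK1) tile of (8, 128, 8), cluster lengths of S<4, 64, 1> spread 4 threads
/// along AK0, 64 along M and 1 along AK1 (4 * 64 * 1 = 256), so each thread copies a
/// (2, 2, 8) sub-tile of the A block per copy pass.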
+/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory. +/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question: "How many threads to +/// arrange on each input data axis?" +/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory. +/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in M dimension. +/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in N dimension. +/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// thread distribution used for storing data into output +/// tensor across output data layout dimensions. +/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. +/// Used when storing data to output tensor. +/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or +/// intrawave). +/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. +/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication +/// instructions. +/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication +/// instructions. +/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout +/// in global memory. Currently not supported! +/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout +/// in global memory (pre-shuffled). Currently not supported! 
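/// Example (hypothetical packed layout, for illustration only): for densely packed
/// row-major A{G,M,K}, row-major B{G,K,N} and row-major C{G,M,N}, the strides passed to
/// this operation would typically be StrideA = K, BatchStrideA = M * K, StrideB = N,
/// BatchStrideB = K * N, StrideC = N and BatchStrideC = M * N.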
+template +struct DeviceBatchedGemm_Wmma_CShuffleV3 : public DeviceBatchedGemm +{ + // We are inheriting from DeviceBatchedGemm and this base class does not support permuteA and + // permuteB arguments so for now we are not including this functionality. + static_assert(PermuteA == false, + "Permute A functionality not supported by DeviceBatchedGemm operations.\n"); + static_assert(PermuteB == false, + "Permute B functionality not supported by DeviceBatchedGemm operations.\n"); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; + }; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, // PermuteA not supported by DeviceBatchedGemm base class. + false>; // PermuteB not supported by DeviceBatchedGemm base class. + + // Argument + struct Argument : public GridwiseGemm::Argument + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t BatchStrideA_, + index_t BatchStrideB_, + index_t BatchStrideC_, + index_t Batch_, + index_t k_batch_, + bool is_reduce_ = false) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + p_c_grid_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + k_batch_, + is_reduce_), + Batch(Batch_), + compute_ptr_offset_of_batch{BatchStrideA_, BatchStrideB_, BatchStrideC_} + { + } + + index_t Batch; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; + }; + + /// @brief Helper structure responsible for kernel invocation. + /// + /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU + /// kernel function. 
It usually determines the launched grid size prepares kernel + /// arguments as well as perform specific kernel configuration selection based on + /// runtime arguments. + /// + /// @note If appropriately configured it may measure kernel execution time. + /// + struct Invoker : public BaseInvoker + { + /// @brief This function issues GPU kernel execution. + /// @param arg The GPU kernel arguments. + /// @param stream_config The HIP stream configuration helper structure. + /// @return The kernel's average execution time (if time measurement is + /// enabled). + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + // The normal approach to batching would be to increase the grid size by just stretching + // out the grid Z dimension (which is the outermost dimension), but this depends on + // lower level functions not directly using the Z dimension for other calculations. As + // it turns out, k batching does rely directly on blockIdx.Z through SplitKBatchOffset. + // Therefore, for now we will use the grid Y dimension for batching. This may be a bit + // fragile. + gdy *= arg.Batch; + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + // Packed sizes are 1 for all implemented data types but we include it anyway + // for future compatibility. + auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * + sizeof(ADataType) / GridwiseGemm::APackedSize; + auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * + sizeof(BDataType) / GridwiseGemm::BPackedSize; + + // Note: the grid descriptors and size_a / size_b do *not* take batching into + // account, so we have to manually multiply overall buffer sizes for rotating + // memory by batch. + ck::utility::RotatingMemWrapper rotating_mem( + arg_, + stream_config.rotating_count, + arg_.Batch * size_a_buffer, + arg_.Batch * size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + // Note: we multiply by batch since we want to clear the C matrix for + // the whole batch. Untested since we don't have k batching ATM. + // Note: This seems incorrect for non-contiguous memory layouts for C + // (padding, gaps). 
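+                            // Illustrative sketch (not used here): a layout-aware clear could
+                            // zero each batch slice with a pitched memset instead, e.g.
+                            //   for(index_t b = 0; b < arg_.Batch; ++b)
+                            //       hipMemset2DAsync(p_c + b * batch_stride_c,
+                            //                        stride_c * sizeof(CDataType), 0,
+                            //                        arg_.N * sizeof(CDataType), arg_.M,
+                            //                        stream_config.stream_id_);
+                            // where p_c, batch_stride_c and stride_c are placeholder names for
+                            // the C pointer and strides of a row-major C tensor (not members of
+                            // this Argument). The simple contiguous memset below is kept as-is.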
+ HIP_CHECK_ERROR( + hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.Batch * arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_, + arg_.compute_ptr_offset_of_batch); + } + else + { + auto clear_workspace = [&]() { + // clear c mem + if(arg.KBatch > 1) + // Note: we multiply by batch since we want to clear the C matrix for + // the whole batch. Untested since we don't have k batching ATM. + // Note: This seems incorrect for non-contiguous memory layouts for C + // (padding, gaps). + HIP_CHECK_ERROR( + hipMemsetAsync(arg.p_c_grid, + 0, + arg.Batch * arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg, + arg.compute_ptr_offset_of_batch); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + else + { + // TODO: Implement + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_wmma_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.KBatch > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == 
GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + // TODO: This is not part of the DeviceBatchedGemm base class but it was part of + // DeviceBatchedGemmV2. Remove? + // index_t GetKPerBlock() override { return KPerBlock; } + // bool GetPermuteA() override { return PermuteA; } + // bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t Batch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + Batch, + 1 /* KBatch */}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t Batch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + Batch, + 1); // KBatch + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceBatchedGemm_Wmma_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " + << "WaveTile: " + << MPerWmma << "x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. 
For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link + * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of + * pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const auto a_grid_desc_k0_m_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA)); + const auto b_grid_desc_k0_n_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1( + karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB)); + const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N( + karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC)); + + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + c_batch_offset, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m_n); +#else + ignore = karg; +#endif +} + +template +struct DeviceBatchedGemmXdl : public DeviceBatchedGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto K1Number = Number{}; + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC) + { + } + + __host__ 
__device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; + }; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpecialization::MNKPadding, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumGemmKPrefetchStage, + LoopSched, + PipelineVer>; + + using Problem = typename GridwiseGemm::Problem; + + // Argument + struct Argument : public Problem, public BaseArgument + { + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t Batch_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + Batch(Batch_), + compute_ptr_offset_of_batch{BatchStrideA, BatchStrideB, BatchStrideC} + { + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + index_t Batch; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceBatchedGemmXdl::Argument; + + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + karg.Print(); + } + + if(!GridwiseGemm::CheckValidity(karg)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting"); + } + + auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N); + gdx *= karg.Batch; + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K)) + { + const auto kernel = + kernel_batched_gemm_xdlops_v2r3; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); + } + else + { + const auto kernel = + kernel_batched_gemm_xdlops_v2r3; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Problem& problem) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(problem); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t Batch) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + Batch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t Batch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + BatchStrideA, + BatchStrideB, + BatchStrideC, + Batch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceBatchedGemmXdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ">" + << " NumGemmKPrefetchStage: " + << NumGemmKPrefetchStage << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp new file mode 100755 index 000000000000..7d9555dc822a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp @@ -0,0 +1,1015 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx = blockIdx.z / karg.Batch; + + const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + const auto b_scale_batch_offset = karg.compute_ptr_offset_of_batch.GetSacleBPtrOffset(g_idx); + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + c_batch_offset + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset, + p_shared, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds(BatchedGemmArg karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t g_idx = blockIdx.z % karg.Batch; + const index_t k_idx = blockIdx.z / karg.Batch; + + const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const auto b_batch_offset = 
karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + const auto b_scale_batch_offset = karg.compute_ptr_offset_of_batch.GetSacleBPtrOffset(g_idx); + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + c_batch_offset + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset, + p_shared_0, + p_shared_1, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchedGemm_Xdl_CShuffleV3_BScale + : public DeviceBatchedGemmV2BScale +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideScaleB) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideC_(BatchStrideC), + BatchStrideScaleB_(BatchStrideScaleB) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_) / BPackedSize; + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + __host__ __device__ constexpr long_index_t GetSacleBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideScaleB_); + } + + private: + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; + index_t BatchStrideScaleB_; + }; + + struct Argument : public GridwiseGemm::Argument + { + index_t Batch; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch; + + Argument(const 
ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + index_t BatchStrideA_, + index_t BatchStrideB_, + index_t BatchStrideC_, + index_t BatchStrideScaleB_, + const BScaleDataType* p_b_scale_grid_, + index_t Batch_, + index_t KBatch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + p_c_grid_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + StrideScaleB_, + p_b_scale_grid_, + KBatch_, // KBatch + a_element_op_, + b_element_op_, + c_element_op_), + Batch(Batch_), + compute_ptr_offset_of_batch( + BatchStrideA_, BatchStrideB_, BatchStrideC_, BatchStrideScaleB_) + { + } + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = + GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch * arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * + sizeof(ADataType) / APackedSize; + auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * + sizeof(BDataType) / BPackedSize; + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave + ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && + MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2) + ? 
2 + : 1 + : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + 
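+                    // KBatch == 1: there is no split-K accumulation, so the epilogue can simply
+                    // Set the C tile instead of using AtomicAdd. The TailNumber dispatch below
+                    // mirrors the KBatch > 1 branch above; only the memory operation differs.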
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< + GridwiseGemm, + Argument, + 
true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_batched_gemm_b_scale_xdl_cshuffle_v3< + GridwiseGemm, + Argument, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideScaleB, + const BScaleDataType* p_b_scale, + index_t Batch, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + 
StrideB, + StrideC, + StrideScaleB, + BatchStrideA, + BatchStrideB, + BatchStrideC, + BatchStrideScaleB, + p_b_scale, + Batch, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + index_t BatchStrideA, + index_t BatchStrideB, + index_t BatchStrideC, + index_t BatchStrideScaleB, + const void* p_b_scale, + index_t Batch, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + BatchStrideA, + BatchStrideB, + BatchStrideC, + BatchStrideScaleB, + static_cast(p_b_scale), + Batch, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/welford_helper.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, + "Invalid thread cluster size assignments!"); + + static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 && + MThreadSliceSize % DySrcVectorSize == 0 && + MThreadSliceSize % DxDstVectorSize == 0) || + (XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 && + KThreadSliceSize % DySrcVectorSize == 0 && + KThreadSliceSize % DxDstVectorSize == 0), + 
"Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeXY2dDescriptor(const std::array& xyLengths, + const std::array& xyStrides, + int blkGroupSize, + int numBlockTileIteration) + { + const auto tupleXYLengths = + generate_tuple([&](auto I) { return xyLengths[I]; }, Number{}); + const auto tupleXYStrides = + generate_tuple([&](auto I) { return xyStrides[I]; }, Number{}); + + const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides); + + const auto grid_desc_m_k = [&]() { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; }, + Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return xyLengths[I]; }, Number{}); + + return transform_tensor_descriptor(raw_grid_desc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{}); + + const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto kPad = workSizePerBlock * blkGroupSize - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_right_pad_transform(reduceLength, kPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_k_padded); + }; + + static auto MakeMultiblockFirstReduceOutputMG2dDescriptor(int invariantLength, int blkGroupSize) + { + const auto grid_desc_m_g = + make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize)); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_g_padded = + transform_tensor_descriptor(grid_desc_m_g, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_pass_through_transform(blkGroupSize)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_g_padded); + }; + + static auto MakeMultiblockFinalReduceInputMK2dDescriptor(int invariantLength, int blkGroupSize) + { + const auto reduceLength = blkGroupSize; + const auto grid_desc_m_k = + make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength)); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto kPad = + math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_right_pad_transform(reduceLength, kPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_k_padded); + }; + + static auto + 
MakeScaleBiasMeanVar1dDescriptor(const std::array& lengths, + const std::array& strides) + { + const auto tupleLengths = + generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + const auto tupleStrides = + generate_tuple([&](auto I) { return strides[I]; }, Number{}); + + auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides); + + auto grid_desc_m = transform_tensor_descriptor( + raw_grid_desc, + make_tuple(make_merge_transform(tupleLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = grid_desc_m.GetLength(Number<0>{}); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_padded = + transform_tensor_descriptor(grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, mPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (grid_desc_m_padded); + }; + + using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1)); + using ScaleBiasGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1})); + using MeanVarGridDesc_M = ScaleBiasGridDesc_M; + + struct Argument : public BaseArgument + { + Argument(const std::array xyLengths, + const std::array xStrides, + const std::array dyStrides, + const std::array dxStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnDscaleDbiasStrides, + const std::array bnMeanVarStrides, + const XDataType* p_x, + const DyDataType* p_dy, + const ScaleDataType* p_scale, + const MeanVarDataType* p_savedMean, + const MeanVarDataType* p_savedInvVar, + const DyElementwiseOp dy_elementwise_op, + double epsilon, + DxDataType* p_dx, + DscaleDbiasDataType* p_dscale, + DscaleDbiasDataType* p_dbias) + : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths), + bnScaleStrides_(bnScaleStrides), + bnDscaleDbiasStrides_(bnDscaleDbiasStrides), + bnMeanVarStrides_(bnMeanVarStrides), + p_x_(p_x), + p_dy_(p_dy), + p_scale_(p_scale), + p_savedMean_(p_savedMean), + p_savedInvVar_(p_savedInvVar), + dy_elementwise_op_(dy_elementwise_op), + p_dx_(p_dx), + p_dscale_(p_dscale), + p_dbias_(p_dbias) + { + xyLengths_ = + shuffle_tensor_dimensions(xyLengths, reduceDims); + xStrides_ = + shuffle_tensor_dimensions(xStrides, reduceDims); + dyStrides_ = + shuffle_tensor_dimensions(dyStrides, reduceDims); + dxStrides_ = + shuffle_tensor_dimensions(dxStrides, reduceDims); + + std::tie(invariant_length, reduce_length) = + get_2d_lengths(xyLengths_); + + epsilon_ = type_convert(epsilon); + + haveSavedMeanInvVar_ = (p_savedMean_ != nullptr && p_savedInvVar_ != nullptr); + + if(UseMultiblockInK) + { + int iterations = 1; + while(true) + { + int testBlkGroupSize = (reduce_length + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + // we want the blkGroupSize be not more than 128 + if(testBlkGroupSize <= 128) + break; + + iterations++; + }; + + blkGroupSize = (reduce_length + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + numBlockTileIteration = iterations; + } + else + { + blkGroupSize = 1; + numBlockTileIteration = (reduce_length + K_BlockTileSize - 1) / K_BlockTileSize; + }; + + gridSize = (invariant_length + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize; + + x_grid_desc_m_k = + MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize, numBlockTileIteration); + dy_grid_desc_m_k = + MakeXY2dDescriptor(xyLengths_, 
dyStrides_, blkGroupSize, numBlockTileIteration); + dx_grid_desc_m_k = + MakeXY2dDescriptor(xyLengths_, dxStrides_, blkGroupSize, numBlockTileIteration); + scale_grid_desc_m = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides); + dscale_dbias_grid_desc_m = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnDscaleDbiasStrides); + mean_var_grid_desc_m = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides); + } + + AccDataType epsilon_; + + bool haveSavedMeanInvVar_; + + std::array xyLengths_; + std::array xStrides_; + std::array dyStrides_; + std::array dxStrides_; + + std::array bnScaleBiasMeanVarLengths_; + std::array bnScaleStrides_; + std::array bnDscaleDbiasStrides_; + std::array bnMeanVarStrides_; + + const XDataType* p_x_; + const DyDataType* p_dy_; + const ScaleDataType* p_scale_; + const MeanVarDataType* p_savedMean_; + const MeanVarDataType* p_savedInvVar_; + const DyElementwiseOp dy_elementwise_op_; + DxDataType* p_dx_; + DscaleDbiasDataType* p_dscale_; + DscaleDbiasDataType* p_dbias_; + + long_index_t invariant_length; + long_index_t reduce_length; + + int blkGroupSize; + int numBlockTileIteration; + size_t gridSize; + + XYGridDesc_M_K x_grid_desc_m_k; + XYGridDesc_M_K dy_grid_desc_m_k; + XYGridDesc_M_K dx_grid_desc_m_k; + ScaleBiasGridDesc_M scale_grid_desc_m; + ScaleBiasGridDesc_M dscale_dbias_grid_desc_m; + MeanVarGridDesc_M mean_var_grid_desc_m; + + void* workspace_mean; + void* workspace_variance; + void* workspace_count; + + void* workspace_savedMean; + void* workspace_savedInvVar; + + void* workspace_reduce_dscale; + void* workspace_reduce_dbias; + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + size_t workspace_size = 0; + + if(UseMultiblockInK && pArg_->blkGroupSize > 1) + { + // workspace for the partial reduced result for dscale + workspace_size += + pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64; + + // workspace for the partial reduced result for dbias + workspace_size += + pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64; + + if(!pArg_->haveSavedMeanInvVar_) + { + // workspace for welford intermediate mean + workspace_size += + pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType) + 64; + + // workspace for welford intermediate variance + workspace_size += + pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType) + 64; + + // workspace for welford intermediate count + workspace_size += + pArg_->invariant_length * pArg_->blkGroupSize * sizeof(int32_t) + 64; + + // workspace for welford result mean + workspace_size += pArg_->invariant_length * sizeof(MeanVarDataType) + 64; + + // workspace for welford result inv_variance + workspace_size += pArg_->invariant_length * sizeof(MeanVarDataType) + 64; + }; + } + + return (workspace_size); + }; + + void SetWorkSpacePointer(BaseArgument* pArg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + + index_t space_sz; + + // setup buffer for the partial reduced result for dscale + pArg_->workspace_reduce_dscale = pArg_->p_workspace_; + + space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType); + space_sz = math::integer_least_multiple(space_sz, 64); + + // setup buffer for the partial reduced result for dbias + pArg_->workspace_reduce_dbias = + 
reinterpret_cast(pArg_->workspace_reduce_dscale) + space_sz; + + if(UseMultiblockInK && pArg_->blkGroupSize > 1) + { + space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType); + space_sz = math::integer_least_multiple(space_sz, 64); + + // setup buffer for welford intermediate mean + pArg_->workspace_mean = + reinterpret_cast(pArg_->workspace_reduce_dbias) + space_sz; + + space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType); + space_sz = math::integer_least_multiple(space_sz, 64); + + // setup buffer for welford intermediate varirance + pArg_->workspace_variance = reinterpret_cast(pArg_->workspace_mean) + space_sz; + + space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType); + space_sz = math::integer_least_multiple(space_sz, 64); + + // setup buffer for welford intermediate count + pArg_->workspace_count = reinterpret_cast(pArg_->workspace_variance) + space_sz; + + space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(int32_t); + space_sz = math::integer_least_multiple(space_sz, 64); + + // setup buffer for welford result mean + pArg_->workspace_savedMean = reinterpret_cast(pArg_->workspace_count) + space_sz; + + space_sz = pArg_->invariant_length * sizeof(MeanVarDataType); + space_sz = math::integer_least_multiple(space_sz, 64); + + // setup buffer for welford result inv_variance + pArg_->workspace_savedInvVar = + reinterpret_cast(pArg_->workspace_savedMean) + space_sz; + }; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0; + + const auto mean_var_count_grid_desc_m_g = + DeviceBatchNormBwdImpl::MakeMultiblockFirstReduceOutputMG2dDescriptor( + arg.invariant_length, arg.blkGroupSize); + + const auto dscale_dbias_grid_desc_m_g = + DeviceBatchNormBwdImpl::MakeMultiblockFirstReduceOutputMG2dDescriptor( + arg.invariant_length, arg.blkGroupSize); + + const auto mean_var_count_grid_desc_m_k = + DeviceBatchNormBwdImpl::MakeMultiblockFinalReduceInputMK2dDescriptor( + arg.invariant_length, arg.blkGroupSize); + + const auto dscale_dbias_grid_desc_m_k = + DeviceBatchNormBwdImpl::MakeMultiblockFinalReduceInputMK2dDescriptor( + arg.invariant_length, arg.blkGroupSize); + + using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g); + using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k); + using DscaleDbiasGridDesc_M_G = decltype(dscale_dbias_grid_desc_m_g); + using DscaleDbiasGridDesc_M_K = decltype(dscale_dbias_grid_desc_m_k); + + using GridwiseWelfordSecondHalfReduceFirstHalf_ = + GridwiseWelfordSecondHalfReduceFirstHalf; + + using GridwiseReduceSecondHalfBatchNormBwdFinal_ = + GridwiseReduceSecondHalfBatchNormBackwardFinal; + + if(UseMultiblockInK && arg.blkGroupSize > 1) + { + using GetReduceCountPerThreadFunctor = + GetReduceCountPerThreadForMultiblockWelford; + + GetReduceCountPerThreadFunctor get_reduce_count_per_thread( + arg.blkGroupSize, arg.numBlockTileIteration, arg.reduce_length); + + if(!arg.haveSavedMeanInvVar_) + { + using GridwiseMultiblockWelfordFirstHalf_ = + GridwiseMultiblockWelfordFirstHalf; + + const auto kern_multiblock_welford_first_half = + kernel_multiblock_welford_first_half; + + avg_time += launch_and_time_kernel( + stream_config, + kern_multiblock_welford_first_half, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k, + mean_var_count_grid_desc_m_g, + get_reduce_count_per_thread, + arg.numBlockTileIteration, + arg.p_x_, 
+ static_cast(arg.workspace_mean), + static_cast(arg.workspace_variance), + static_cast(arg.workspace_count)); + }; + + const auto kern_welford_second_half_reduce_first_half = + kernel_welford_second_half_reduce_first_half< + GridwiseWelfordSecondHalfReduceFirstHalf_, + XDataType, + DyDataType, + AccDataType, + ScaleDataType, + DscaleDbiasDataType, + MeanVarDataType, + DyElementwiseOp, + XYGridDesc_M_K, + MeanVarGridDesc_M, + MeanVarCountGridDesc_M_K, + DscaleDbiasGridDesc_M_G>; + + const auto kern_reduce_second_half_batchnorm_backward_final = + kernel_reduce_second_half_batchnorm_backward_final< + GridwiseReduceSecondHalfBatchNormBwdFinal_, + XDataType, + DyDataType, + DxDataType, + ScaleDataType, + DscaleDbiasDataType, + MeanVarDataType, + DyElementwiseOp, + XYGridDesc_M_K, + DscaleDbiasGridDesc_M_K, + MeanVarGridDesc_M, + ScaleBiasGridDesc_M>; + + index_t numDscaleDbiasBlockTileIteration = + (arg.blkGroupSize + KThreadClusterSize - 1) / KThreadClusterSize; + + avg_time += launch_and_time_kernel( + stream_config, + kern_welford_second_half_reduce_first_half, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k, + arg.dy_grid_desc_m_k, + arg.mean_var_grid_desc_m, + mean_var_count_grid_desc_m_k, + dscale_dbias_grid_desc_m_g, + arg.blkGroupSize, + arg.numBlockTileIteration, + numDscaleDbiasBlockTileIteration, + arg.epsilon_, + arg.haveSavedMeanInvVar_, + arg.haveSavedMeanInvVar_ ? arg.p_savedMean_ : nullptr, + arg.haveSavedMeanInvVar_ ? arg.p_savedInvVar_ : nullptr, + arg.haveSavedMeanInvVar_ + ? nullptr + : static_cast(arg.workspace_mean), + arg.haveSavedMeanInvVar_ + ? nullptr + : static_cast(arg.workspace_variance), + arg.haveSavedMeanInvVar_ ? nullptr + : static_cast(arg.workspace_count), + arg.dy_elementwise_op_, + arg.haveSavedMeanInvVar_ + ? nullptr + : static_cast(arg.workspace_savedMean), + arg.haveSavedMeanInvVar_ + ? nullptr + : static_cast(arg.workspace_savedInvVar), + arg.p_x_, + arg.p_dy_, + static_cast(arg.workspace_reduce_dscale), + static_cast(arg.workspace_reduce_dbias)); + + avg_time += launch_and_time_kernel( + stream_config, + kern_reduce_second_half_batchnorm_backward_final, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k, + arg.dy_grid_desc_m_k, + arg.dx_grid_desc_m_k, + dscale_dbias_grid_desc_m_k, + arg.mean_var_grid_desc_m, + arg.scale_grid_desc_m, + arg.dscale_dbias_grid_desc_m, + arg.blkGroupSize, + arg.reduce_length, + arg.numBlockTileIteration, + numDscaleDbiasBlockTileIteration, + static_cast(arg.workspace_reduce_dscale), + static_cast(arg.workspace_reduce_dbias), + arg.haveSavedMeanInvVar_ + ? arg.p_savedMean_ + : static_cast(arg.workspace_savedMean), + arg.haveSavedMeanInvVar_ + ? 
arg.p_savedInvVar_ + : static_cast(arg.workspace_savedInvVar), + arg.p_x_, + arg.p_dy_, + arg.p_scale_, + arg.dy_elementwise_op_, + arg.p_dx_, + arg.p_dscale_, + arg.p_dbias_); + } + else + { + using GetReduceCountPerThreadFunctor = + GetReduceCountPerThreadForBlockwiseWelford; + + GetReduceCountPerThreadFunctor get_reduce_count_per_thread( + arg.numBlockTileIteration, arg.reduce_length); + + using GridwiseBatchNormBackwardWithBlockwiseWelford_ = + GridwiseBatchNormBackwardWithBlockwiseWelford; + + const auto kern_batchnorm_bwd = kernel_batchnorm_backward_with_blockwise_welford< + GridwiseBatchNormBackwardWithBlockwiseWelford_, + XDataType, + DyDataType, + DxDataType, + AccDataType, + ScaleDataType, + DscaleDbiasDataType, + MeanVarDataType, + DyElementwiseOp, + XYGridDesc_M_K, + ScaleBiasGridDesc_M, + MeanVarGridDesc_M, + GetReduceCountPerThreadFunctor>; + + avg_time += launch_and_time_kernel(stream_config, + kern_batchnorm_bwd, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k, + arg.dy_grid_desc_m_k, + arg.dx_grid_desc_m_k, + arg.scale_grid_desc_m, + arg.dscale_dbias_grid_desc_m, + arg.mean_var_grid_desc_m, + get_reduce_count_per_thread, + arg.reduce_length, + arg.numBlockTileIteration, + arg.epsilon_, + arg.p_x_, + arg.p_dy_, + arg.p_scale_, + arg.haveSavedMeanInvVar_, + arg.p_savedMean_, + arg.p_savedInvVar_, + arg.dy_elementwise_op_, + arg.p_dx_, + arg.p_dscale_, + arg.p_dbias_); + }; + + return (avg_time); + }; + + float Run(const BaseArgument* pArg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(pArg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* pArg) override + { + const Argument* pArg_ = dynamic_cast(pArg); + + if constexpr(XDyDxVectorDim == 0) + { + if(pArg_->xStrides_[NumInvariantDim - 1] != 1 || + pArg_->dyStrides_[NumInvariantDim - 1] != 1 || + pArg_->dxStrides_[NumInvariantDim - 1] != 1) + return false; + + if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 || + pArg_->xyLengths_[NumInvariantDim - 1] % DySrcVectorSize != 0 || + pArg_->xyLengths_[NumInvariantDim - 1] % DxDstVectorSize != 0) + return false; + } + else + { + if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->dyStrides_[Rank - 1] != 1 || + pArg_->dxStrides_[Rank - 1] != 1) + return false; + + if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 || + pArg_->xyLengths_[Rank - 1] % DySrcVectorSize != 0 || + pArg_->xyLengths_[Rank - 1] % DxDstVectorSize != 0) + return false; + }; + + if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1) + return false; + + if(pArg_->bnDscaleDbiasStrides_[NumInvariantDim - 1] != 1 && DscaleDbiasDstVectorSize != 1) + return false; + + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0) + return false; + + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % DscaleDbiasDstVectorSize != 0) + return false; + + if(pArg_->haveSavedMeanInvVar_) + { + if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcVectorSize != 1) + return false; + + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcVectorSize != 0) + return false; + }; + + bool is_valid = true; + + static_for<0, NumInvariantDim, 1>{}([&](auto I) { + if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I]) + is_valid = false; + }); + + if(!is_valid) + return false; + + return true; + }; + + std::unique_ptr + MakeArgumentPointer(const std::array xyLengths, + const std::array xStrides, + const std::array dyStrides, + const 
std::array dxStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnDscaleDbiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* p_dy, + const void* p_scale, + const void* p_savedMean, + const void* p_savedInvVar, + double epsilon, + const DyElementwiseOp dy_elementwise_op, + void* p_dx, + void* p_dscale, + void* p_dbias) override + { + return std::make_unique(xyLengths, + xStrides, + dyStrides, + dxStrides, + reduceDims, + bnScaleBiasMeanVarLengths, + bnScaleStrides, + bnDscaleDbiasStrides, + bnMeanVarStrides, + static_cast(p_x), + static_cast(p_dy), + static_cast(p_scale), + static_cast(p_savedMean), + static_cast(p_savedInvVar), + dy_elementwise_op, + epsilon, + static_cast(p_dx), + static_cast(p_dscale), + static_cast(p_dbias)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchNormBwdImpl<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "XDyDxVectorDim_" << XDyDxVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << DscaleDbiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; // namespace device + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp new file mode 100755 index 000000000000..e7e4668d92e7 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp @@ -0,0 +1,824 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
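The forward implementation that this header introduces follows a two-phase pattern on its multi-block path: when the reduce dimension is split across workgroups, a first-half kernel writes one partial Welford result per (invariant index, k-block) into the mean, variance and count workspaces, and a second-half kernel combines those partials per invariant index (the channel, in the usual layout) before producing y and the saved/running statistics. The combining step is the standard parallel Welford merge; the sketch below illustrates only that step, standalone, with hypothetical `WelfordPartial`, `Merge` and `MergeAll` helpers and `float` accumulators. Whether the gridwise kernels keep a plain variance or a running M2 in the workspace is internal to them and not shown in this file, so the sketch assumes population variance.

```cpp
#include <vector>

// One partial Welford result, as conceptually produced by a single workgroup
// (what the first-half kernel writes into the workspace buffers).
struct WelfordPartial
{
    float mean     = 0.f;
    float variance = 0.f; // population variance of the elements seen so far
    int   count    = 0;
};

// Merge two partials (Chan et al. parallel combination). The second-half
// pass performs an equivalent reduction over the blkGroupSize partials of a
// channel before computing y = scale * (x - mean) * invVar + bias.
inline WelfordPartial Merge(const WelfordPartial& a, const WelfordPartial& b)
{
    WelfordPartial out;
    out.count = a.count + b.count;
    if(out.count == 0)
        return out;

    const float delta = b.mean - a.mean;
    out.mean =
        a.mean + delta * static_cast<float>(b.count) / static_cast<float>(out.count);

    // combine M2 = variance * count, then convert back to a variance
    const float m2 = a.variance * a.count + b.variance * b.count +
                     delta * delta * static_cast<float>(a.count) *
                         static_cast<float>(b.count) / static_cast<float>(out.count);
    out.variance = m2 / static_cast<float>(out.count);
    return out;
}

// Example: reduce the per-workgroup partials of one channel.
inline WelfordPartial MergeAll(const std::vector<WelfordPartial>& partials)
{
    WelfordPartial acc;
    for(const auto& p : partials)
        acc = Merge(acc, p);
    return acc;
}
```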
+ +#pragma once + +#include +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/device/welford_helper.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, + "Invalid thread cluster size assignments!"); + + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeXY2dDescriptor(const std::array& xyLengths, + const std::array& xyStrides, + int blkGroupSize, + int numBlockTileIteration) + { + const auto tupleXYLengths = + generate_tuple([&](auto I) { return xyLengths[I]; }, Number{}); + const auto tupleXYStrides = + generate_tuple([&](auto I) { return xyStrides[I]; }, Number{}); + + const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides); + + const auto grid_desc_m_k = [&]() { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; }, + Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return xyLengths[I]; }, Number{}); + + return transform_tensor_descriptor(raw_grid_desc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{}); + + const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto kPad = workSizePerBlock * blkGroupSize - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_right_pad_transform(reduceLength, kPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_k_padded); + }; + + static 
auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize) + { + const auto grid_desc_m_g = make_naive_tensor_descriptor( + make_tuple(invariantLength, blkGroupSize), make_tuple(1, invariantLength)); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_g_padded = + transform_tensor_descriptor(grid_desc_m_g, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_pass_through_transform(blkGroupSize)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_g_padded); + }; + + static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize) + { + const auto reduceLength = blkGroupSize; + const auto grid_desc_m_k = make_naive_tensor_descriptor( + make_tuple(invariantLength, reduceLength), make_tuple(1, invariantLength)); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto kPad = + math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_right_pad_transform(reduceLength, kPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_k_padded); + }; + + static auto + MakeScaleBiasMeanVar1dDescriptor(const std::array& lengths, + const std::array& strides) + { + const auto tupleLengths = + generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + const auto tupleStrides = + generate_tuple([&](auto I) { return strides[I]; }, Number{}); + + auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides); + + auto grid_desc_m = transform_tensor_descriptor( + raw_grid_desc, + make_tuple(make_merge_transform(tupleLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = grid_desc_m.GetLength(Number<0>{}); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_padded = + transform_tensor_descriptor(grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, mPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (grid_desc_m_padded); + }; + + using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1)); + using ScaleBiasMeanVarGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1})); + + struct Argument : public BaseArgument + { + Argument(const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const XDataType* p_x, + const ScaleDataType* p_scale, + const BiasDataType* p_bias, + const YElementwiseOp y_elementwise_op, + double epsilon, + YDataType* p_y, + MeanVarDataType* resultSaveMean, + MeanVarDataType* resultSaveInvVariance, + double averageFactor, + MeanVarDataType* resultRunningMean, + MeanVarDataType* resultRunningVariance) + : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths), + bnScaleStrides_(bnScaleStrides), + bnBiasStrides_(bnBiasStrides), + bnMeanVarStrides_(bnMeanVarStrides), + p_x_(p_x), + p_scale_(p_scale), + p_bias_(p_bias), + y_elementwise_op_(y_elementwise_op), + 
p_y_(p_y), + resultSaveMean_(resultSaveMean), + resultSaveInvVariance_(resultSaveInvVariance), + resultRunningMean_(resultRunningMean), + resultRunningVariance_(resultRunningVariance) + { + xyLengths_ = + shuffle_tensor_dimensions(xyLengths, reduceDims); + xStrides_ = + shuffle_tensor_dimensions(xStrides, reduceDims); + yStrides_ = + shuffle_tensor_dimensions(yStrides, reduceDims); + + std::tie(invariant_length_, reduce_length_) = + get_2d_lengths(xyLengths_); + + epsilon_ = type_convert(epsilon); + averageFactor_ = type_convert(averageFactor); + + updateMovingAverage_ = + (resultRunningMean != nullptr && resultRunningVariance != nullptr); + saveMeanInvVariance_ = (resultSaveMean != nullptr && resultSaveInvVariance_ != nullptr); + + if(UseMultiblockInK) + { + int iterations = 1; + while(true) + { + int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + // we want the blkGroupSize be not more than 16 + if(testBlkGroupSize <= 16) + break; + + iterations++; + }; + + blkGroupSize_ = (reduce_length_ + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + numBlockTileIteration_ = iterations; + } + else + { + blkGroupSize_ = 1; + numBlockTileIteration_ = (reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize; + }; + + gridSize_ = (invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize_; + + x_grid_desc_m_k_ = + MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize_, numBlockTileIteration_); + y_grid_desc_m_k_ = + MakeXY2dDescriptor(xyLengths_, yStrides_, blkGroupSize_, numBlockTileIteration_); + scale_grid_desc_m_ = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides_); + bias_grid_desc_m_ = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides_); + mean_var_grid_desc_m_ = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides_); + } + + AccDataType epsilon_; + AccDataType averageFactor_; + + bool updateMovingAverage_; + bool saveMeanInvVariance_; + + std::array xyLengths_; + std::array xStrides_; + std::array yStrides_; + + std::array bnScaleBiasMeanVarLengths_; + std::array bnScaleStrides_; + std::array bnBiasStrides_; + std::array bnMeanVarStrides_; + + const XDataType* p_x_; + const ScaleDataType* p_scale_; + const BiasDataType* p_bias_; + const YElementwiseOp y_elementwise_op_; + YDataType* p_y_; + + MeanVarDataType* resultSaveMean_; + MeanVarDataType* resultSaveInvVariance_; + + MeanVarDataType* resultRunningMean_; + MeanVarDataType* resultRunningVariance_; + + long_index_t invariant_length_; + long_index_t reduce_length_; + + int blkGroupSize_; + int numBlockTileIteration_; + size_t gridSize_; + + XYGridDesc_M_K x_grid_desc_m_k_; + XYGridDesc_M_K y_grid_desc_m_k_; + ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_; + ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_; + ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_; + + void* workspace_mean_; + void* workspace_variance_; + void* workspace_count_; + + void* control_; + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + size_t workspace_size = 0; + + if(UseMultiblockInK && pArg_->blkGroupSize_ > 1) + { + // workspace for welford intermediate mean + workspace_size += + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64; + + // workspace for welford intermediate variance + workspace_size += + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64; + + 
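Each multi-block workspace entry here reserves 64 extra bytes so that `SetWorkSpacePointer` can round the start of the following sub-buffer up to a 64-byte boundary, and the final entry (added just below) holds one two-integer barrier object per M tile for the single-kernel variant. As a standalone illustration of the same arithmetic, here is a host-side sketch with hypothetical tile sizes and `float` standing in for `MeanVarDataType`; names such as `ChooseIterations` and `ForwardWorkspaceBytes` are invented for the example and are not part of the patch.

```cpp
#include <cstddef>
#include <cstdint>

// Illustrative constants; in the real header these come from the template
// parameters (K_BlockTileSize = KThreadClusterSize * KThreadSliceSize, etc.).
constexpr int kKBlockTileSize = 256;
constexpr int kMBlockTileSize = 64;

// Smallest per-block iteration count such that the number of K blocks
// (blkGroupSize) does not exceed 16 -- mirrors the loop in the constructor.
inline int ChooseIterations(long reduceLength)
{
    int iterations = 1;
    while((reduceLength + kKBlockTileSize * iterations - 1) /
              (kKBlockTileSize * iterations) > 16)
        ++iterations;
    return iterations;
}

// Workspace bytes for the multi-block forward path, assuming
// MeanVarDataType == float. Each buffer gets +64 so the next one can be
// rounded up to a 64-byte boundary when the base pointer is carved up.
inline std::size_t ForwardWorkspaceBytes(long invariantLength, long reduceLength)
{
    const int  iterations = ChooseIterations(reduceLength);
    const long blkGroup   = (reduceLength + kKBlockTileSize * iterations - 1) /
                            (kKBlockTileSize * iterations);
    if(blkGroup <= 1)
        return 0; // single-block path needs no workspace

    std::size_t bytes = 0;
    bytes += invariantLength * blkGroup * sizeof(float)   + 64; // partial means
    bytes += invariantLength * blkGroup * sizeof(float)   + 64; // partial variances
    bytes += invariantLength * blkGroup * sizeof(int32_t) + 64; // partial counts
    // two ints per M tile for the inter-workgroup barrier objects
    bytes += (invariantLength + kMBlockTileSize - 1) / kMBlockTileSize *
             sizeof(int) * 2;
    return bytes;
}
```

The `iterations` search reproduces the constructor loop above: it picks the smallest per-block iteration count that keeps `blkGroupSize` at or below 16.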
// workspace for welford intermediate count + workspace_size += + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64; + + // workspace for barrier objects, each barrier object consists of two integers + // TODO: allocate barrier object memory globally to reuse it by other operators + workspace_size += (pArg_->invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * + sizeof(int) * 2; + } + + return (workspace_size); + }; + + void SetWorkSpacePointer(BaseArgument* pArg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + + if(UseMultiblockInK && pArg_->blkGroupSize_ > 1) + { + // setup buffer used for intermediate welford mean + pArg_->workspace_mean_ = static_cast(pArg_->p_workspace_); + + index_t mean_space_sz = + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType); + + mean_space_sz = math::integer_least_multiple(mean_space_sz, 64); + + // setup buffer used for intermediate welford varirance + pArg_->workspace_variance_ = + reinterpret_cast(pArg_->workspace_mean_) + mean_space_sz; + + index_t variance_space_sz = + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType); + + variance_space_sz = math::integer_least_multiple(variance_space_sz, 64); + + // setup buffer used for intermediate welfor count + pArg_->workspace_count_ = + reinterpret_cast(pArg_->workspace_variance_) + variance_space_sz; + + index_t count_space_sz = + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t); + + count_space_sz = math::integer_least_multiple(count_space_sz, 64); + + pArg_->control_ = reinterpret_cast(pArg_->workspace_count_) + count_space_sz; + + index_t control_space_sz = (pArg_->invariant_length_ + M_BlockTileSize - 1) / + M_BlockTileSize * sizeof(int) * 2; + + hip_check_error(hipMemset(pArg_->control_, 0, control_space_sz)); + }; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0; + + if(UseMultiblockInK && arg.blkGroupSize_ > 1) + { + using GetReduceCountPerThreadFunctor = + GetReduceCountPerThreadForMultiblockWelford; + + GetReduceCountPerThreadFunctor get_reduce_count_per_thread( + arg.blkGroupSize_, arg.numBlockTileIteration_, arg.reduce_length_); + + const auto mean_var_count_grid_desc_m_g = + DeviceBatchNormFwdImpl::MakeMeanVarCountOutputMG2dDescriptor( + arg.invariant_length_, arg.blkGroupSize_); + + const auto mean_var_count_grid_desc_m_k = + DeviceBatchNormFwdImpl::MakeMeanVarCountInputMK2dDescriptor( + arg.invariant_length_, arg.blkGroupSize_); + + using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g); + using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k); + + using GridwiseMultiblockBatchNormForward_ = + GridwiseMultiblockBatchNormForward; + + using GridwiseMultiblockWelfordFirstHalf_ = + GridwiseMultiblockWelfordFirstHalf; + + using GridwiseWelfordSecondHalfBatchNormForwardFinal_ = + GridwiseWelfordSecondHalfBatchNormForwardFinal; + + // It is found that: + // 1) gfx1030 does not support the GLC enabled vector load/store, so using the + // two-kernel method for gfx1030 + // 2) Profiler on gfx908 could hang even though it works when running examples + // 3) Single-kernel method works on gfx1100, but the performance it not better + // than two-kernel method (due to more warps participating the barrier) + if(ck::get_device_name() == "gfx90a") + { + const auto 
kern_multiblock_batchnorm_fwd_ = + kernel_multiblock_batchnorm_forward; + + avg_time += launch_and_time_kernel( + stream_config, + kern_multiblock_batchnorm_fwd_, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + mean_var_count_grid_desc_m_g, // for writing to mean/variance/count + // workspace by multiple workgroups + mean_var_count_grid_desc_m_k, // for reading from mean/variance/count + // workspace by each workgroup + arg.scale_grid_desc_m_, + arg.bias_grid_desc_m_, + arg.mean_var_grid_desc_m_, + get_reduce_count_per_thread, + arg.numBlockTileIteration_, + arg.epsilon_, + arg.p_x_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_), + static_cast(arg.control_), + arg.p_scale_, + arg.p_bias_, + arg.y_elementwise_op_, + arg.p_y_, + arg.updateMovingAverage_, // true or false + arg.averageFactor_, + arg.resultRunningMean_, + arg.resultRunningVariance_, + arg.saveMeanInvVariance_, // true or false + arg.resultSaveMean_, + arg.resultSaveInvVariance_); + } + else + { + const auto kern_multiblock_welford_first_half = + kernel_multiblock_welford_first_half; + + const auto kern_welford_second_half_batchnorm_forward_final = + kernel_welford_second_half_batchnorm_forward_final< + GridwiseWelfordSecondHalfBatchNormForwardFinal_, + XDataType, + YDataType, + AccDataType, + ScaleDataType, + BiasDataType, + MeanVarDataType, + YElementwiseOp, + XYGridDesc_M_K, + MeanVarCountGridDesc_M_K, + ScaleBiasMeanVarGridDesc_M, + ScaleBiasMeanVarGridDesc_M>; + + avg_time += launch_and_time_kernel( + stream_config, + kern_multiblock_welford_first_half, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + mean_var_count_grid_desc_m_g, + get_reduce_count_per_thread, + arg.numBlockTileIteration_, + arg.p_x_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_)); + + avg_time += launch_and_time_kernel( + stream_config, + kern_welford_second_half_batchnorm_forward_final, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + mean_var_count_grid_desc_m_k, + arg.scale_grid_desc_m_, + arg.bias_grid_desc_m_, + arg.mean_var_grid_desc_m_, + arg.blkGroupSize_, + arg.numBlockTileIteration_, + arg.epsilon_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_), + arg.p_x_, + arg.p_scale_, + arg.p_bias_, + arg.y_elementwise_op_, + arg.p_y_, + arg.updateMovingAverage_, + arg.averageFactor_, + arg.resultRunningMean_, + arg.resultRunningVariance_, + arg.saveMeanInvVariance_, + arg.resultSaveMean_, + arg.resultSaveInvVariance_); + }; + } + else + { + using GetReduceCountPerThreadFunctor = + GetReduceCountPerThreadForBlockwiseWelford; + + GetReduceCountPerThreadFunctor get_reduce_count_per_thread( + arg.numBlockTileIteration_, arg.reduce_length_); + + using GridwiseBatchNormForwardWithBlockwiseWelford_ = + GridwiseBatchNormForwardWithBlockwiseWelford; + + const auto kern_batchnorm_fwd = kernel_batchnorm_forward_with_blockwise_welford< + GridwiseBatchNormForwardWithBlockwiseWelford_, + XDataType, + YDataType, + AccDataType, + ScaleDataType, + BiasDataType, + MeanVarDataType, + YElementwiseOp, + XYGridDesc_M_K, + ScaleBiasMeanVarGridDesc_M, + ScaleBiasMeanVarGridDesc_M, + GetReduceCountPerThreadFunctor>; + + avg_time += launch_and_time_kernel(stream_config, + kern_batchnorm_fwd, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + 
arg.y_grid_desc_m_k_, + arg.scale_grid_desc_m_, + arg.bias_grid_desc_m_, + arg.mean_var_grid_desc_m_, + get_reduce_count_per_thread, + arg.numBlockTileIteration_, + arg.epsilon_, + arg.p_x_, + arg.p_scale_, + arg.p_bias_, + arg.y_elementwise_op_, + arg.p_y_, + arg.updateMovingAverage_, // true or false + arg.averageFactor_, + arg.resultRunningMean_, + arg.resultRunningVariance_, + arg.saveMeanInvVariance_, // true or false + arg.resultSaveMean_, + arg.resultSaveInvVariance_); + }; + + return (avg_time); + }; + + float Run(const BaseArgument* pArg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(pArg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* pArg) override + { + const Argument* pArg_ = dynamic_cast(pArg); + + if constexpr(XSrcYDstVectorDim == 0) + { + if(pArg_->xStrides_[NumInvariantDim - 1] != 1 || + pArg_->yStrides_[NumInvariantDim - 1] != 1) + return false; + + if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 || + pArg_->xyLengths_[NumInvariantDim - 1] % YDstVectorSize != 0) + return false; + } + else + { + if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->yStrides_[Rank - 1] != 1) + return false; + + if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 || + pArg_->xyLengths_[Rank - 1] % YDstVectorSize != 0) + return false; + }; + + if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1) + return false; + if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasSrcVectorSize != 1) + return false; + + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0) + return false; + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasSrcVectorSize != 0) + return false; + + if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcDstVectorSize != 1) + return false; + + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcDstVectorSize != 0) + return false; + + bool is_valid = true; + + static_for<0, NumInvariantDim, 1>{}([&](auto I) { + if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I]) + is_valid = false; + }); + + if(!is_valid) + return false; + + return true; + }; + + std::unique_ptr MakeArgumentPointer( + const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* p_scale, + const void* p_bias, + double epsilon, + const YElementwiseOp y_elementwise_op, + void* p_y, + void* resultSaveMean, + void* resultSaveInvVariance, + double averageFactor, + void* resultRunningMean, + void* resultRunningVariance) override + { + return std::make_unique(xyLengths, + xStrides, + yStrides, + reduceDims, + bnScaleBiasMeanVarLengths, + bnScaleStrides, + bnBiasStrides, + bnMeanVarStrides, + static_cast(p_x), + static_cast(p_scale), + static_cast(p_bias), + y_elementwise_op, + epsilon, + static_cast(p_y), + static_cast(resultSaveMean), + static_cast(resultSaveInvVariance), + averageFactor, + static_cast(resultRunningMean), + static_cast(resultRunningVariance)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchNormFwdImpl<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << 
","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "XSrcYDstVectorDim_" << XSrcYDstVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize << "_Y" << YDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp new file mode 100755 index 000000000000..c3e0837722cf --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl_obsolete.hpp @@ -0,0 +1,716 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/device/welford_helper.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp" +#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, + "Invalid thread cluster size assignments!"); + + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeXY2dDescriptor(const std::array& xyLengths, + const std::array& xyStrides, + int blkGroupSize, + int numBlockTileIteration) + { + const auto tupleXYLengths = + generate_tuple([&](auto I) { return xyLengths[I]; }, Number{}); + const auto tupleXYStrides = + generate_tuple([&](auto I) { return xyStrides[I]; }, Number{}); + + const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides); + + const auto grid_desc_m_k = [&]() { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; }, + Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return xyLengths[I]; }, Number{}); + + return transform_tensor_descriptor(raw_grid_desc, + make_tuple(make_merge_transform(invariantDimLengths), 
+ make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{}); + + const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto kPad = workSizePerBlock * blkGroupSize - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_right_pad_transform(reduceLength, kPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_k_padded); + }; + + static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize) + { + const auto grid_desc_m_g = make_naive_tensor_descriptor( + make_tuple(invariantLength, blkGroupSize), make_tuple(1, invariantLength)); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_g_padded = + transform_tensor_descriptor(grid_desc_m_g, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_pass_through_transform(blkGroupSize)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_g_padded); + }; + + static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize) + { + const auto reduceLength = blkGroupSize; + const auto grid_desc_m_k = make_naive_tensor_descriptor( + make_tuple(invariantLength, reduceLength), make_tuple(1, invariantLength)); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto kPad = + math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, mPad), + make_right_pad_transform(reduceLength, kPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (grid_desc_m_k_padded); + }; + + static auto + MakeScaleBiasMeanVar1dDescriptor(const std::array& lengths, + const std::array& strides) + { + const auto tupleLengths = + generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + const auto tupleStrides = + generate_tuple([&](auto I) { return strides[I]; }, Number{}); + + auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides); + + auto grid_desc_m = transform_tensor_descriptor( + raw_grid_desc, + make_tuple(make_merge_transform(tupleLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = grid_desc_m.GetLength(Number<0>{}); + + const auto mPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_padded = + transform_tensor_descriptor(grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, mPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (grid_desc_m_padded); + }; + + using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1)); + using ScaleBiasMeanVarGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1})); + + struct Argument : public BaseArgument + { + Argument(const std::array xyLengths, + 
const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const XDataType* p_x, + const ScaleDataType* p_scale, + const BiasDataType* p_bias, + const YElementwiseOp y_elementwise_op, + double epsilon, + YDataType* p_y, + MeanVarDataType* resultSaveMean, + MeanVarDataType* resultSaveInvVariance, + double averageFactor, + MeanVarDataType* resultRunningMean, + MeanVarDataType* resultRunningVariance) + : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths), + bnScaleStrides_(bnScaleStrides), + bnBiasStrides_(bnBiasStrides), + bnMeanVarStrides_(bnMeanVarStrides), + p_x_(p_x), + p_scale_(p_scale), + p_bias_(p_bias), + y_elementwise_op_(y_elementwise_op), + p_y_(p_y), + resultSaveMean_(resultSaveMean), + resultSaveInvVariance_(resultSaveInvVariance), + resultRunningMean_(resultRunningMean), + resultRunningVariance_(resultRunningVariance) + { + xyLengths_ = + shuffle_tensor_dimensions(xyLengths, reduceDims); + xStrides_ = + shuffle_tensor_dimensions(xStrides, reduceDims); + yStrides_ = + shuffle_tensor_dimensions(yStrides, reduceDims); + + std::tie(invariant_length_, reduce_length_) = + get_2d_lengths(xyLengths_); + + epsilon_ = type_convert(epsilon); + averageFactor_ = type_convert(averageFactor); + + updateMovingAverage_ = + (resultRunningMean != nullptr && resultRunningVariance != nullptr); + saveMeanInvVariance_ = (resultSaveMean != nullptr && resultSaveInvVariance_ != nullptr); + + if(UseMultiblockInK) + { + int iterations = 1; + while(true) + { + int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + // we want the blkGroupSize be not more than 16 + if(testBlkGroupSize <= 16) + break; + + iterations++; + }; + + blkGroupSize_ = (reduce_length_ + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + numBlockTileIteration_ = iterations; + } + else + { + blkGroupSize_ = 1; + numBlockTileIteration_ = (reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize; + }; + + gridSize_ = (invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize_; + + x_grid_desc_m_k_ = + MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize_, numBlockTileIteration_); + y_grid_desc_m_k_ = + MakeXY2dDescriptor(xyLengths_, yStrides_, blkGroupSize_, numBlockTileIteration_); + scale_grid_desc_m_ = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides_); + bias_grid_desc_m_ = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides_); + mean_var_grid_desc_m_ = + MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides_); + } + + AccDataType epsilon_; + AccDataType averageFactor_; + + bool updateMovingAverage_; + bool saveMeanInvVariance_; + + std::array xyLengths_; + std::array xStrides_; + std::array yStrides_; + + std::array bnScaleBiasMeanVarLengths_; + std::array bnScaleStrides_; + std::array bnBiasStrides_; + std::array bnMeanVarStrides_; + + const XDataType* p_x_; + const ScaleDataType* p_scale_; + const BiasDataType* p_bias_; + const YElementwiseOp y_elementwise_op_; + YDataType* p_y_; + + MeanVarDataType* resultSaveMean_; + MeanVarDataType* resultSaveInvVariance_; + + MeanVarDataType* resultRunningMean_; + MeanVarDataType* resultRunningVariance_; + + long_index_t invariant_length_; + long_index_t reduce_length_; + + int blkGroupSize_; + int numBlockTileIteration_; + size_t 
gridSize_; + + XYGridDesc_M_K x_grid_desc_m_k_; + XYGridDesc_M_K y_grid_desc_m_k_; + ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_; + ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_; + ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_; + + void* workspace_mean_; + void* workspace_variance_; + void* workspace_count_; + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + size_t workspace_size = 0; + + if(UseMultiblockInK && pArg_->blkGroupSize_ > 1) + { + // workspace for welford intermediate mean + workspace_size += + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64; + + // workspace for welford intermediate variance + workspace_size += + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64; + + // workspace for welford intermediate count + workspace_size += + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64; + } + + return (workspace_size); + }; + + void SetWorkSpacePointer(BaseArgument* pArg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + + if(UseMultiblockInK && pArg_->blkGroupSize_ > 1) + { + + // setup buffer used for intermediate welford mean + pArg_->workspace_mean_ = static_cast(pArg_->p_workspace_); + + index_t mean_space_sz = + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType); + + mean_space_sz = math::integer_least_multiple(mean_space_sz, 64); + + // setup buffer used for intermediate welford varirance + pArg_->workspace_variance_ = + reinterpret_cast(pArg_->workspace_mean_) + mean_space_sz; + + index_t variance_space_sz = + pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType); + + variance_space_sz = math::integer_least_multiple(variance_space_sz, 64); + + // setup buffer used for intermediate welfor count + pArg_->workspace_count_ = + reinterpret_cast(pArg_->workspace_variance_) + variance_space_sz; + }; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0; + + if(UseMultiblockInK && arg.blkGroupSize_ > 1) + { + using GetReduceCountPerThreadFunctor = + GetReduceCountPerThreadForMultiblockWelford; + + GetReduceCountPerThreadFunctor get_reduce_count_per_thread( + arg.blkGroupSize_, arg.numBlockTileIteration_, arg.reduce_length_); + + const auto mean_var_count_grid_desc_m_g = + DeviceBatchNormFwdImpl::MakeMeanVarCountOutputMG2dDescriptor( + arg.invariant_length_, arg.blkGroupSize_); + + const auto mean_var_count_grid_desc_m_k = + DeviceBatchNormFwdImpl::MakeMeanVarCountInputMK2dDescriptor( + arg.invariant_length_, arg.blkGroupSize_); + + using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g); + using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k); + + using GridwiseMultiblockWelfordFirstHalf_ = + GridwiseMultiblockWelfordFirstHalf; + + using GridwiseWelfordSecondHalfBatchNormForwardFinal_ = + GridwiseWelfordSecondHalfBatchNormForwardFinal; + + const auto kern_multiblock_welford_first_half = + kernel_multiblock_welford_first_half; + + const auto kern_welford_second_half_batchnorm_forward_final = + kernel_welford_second_half_batchnorm_forward_final< + GridwiseWelfordSecondHalfBatchNormForwardFinal_, + XDataType, + YDataType, + AccDataType, + ScaleDataType, + BiasDataType, + MeanVarDataType, + YElementwiseOp, + XYGridDesc_M_K, + 
MeanVarCountGridDesc_M_K, + ScaleBiasMeanVarGridDesc_M, + ScaleBiasMeanVarGridDesc_M>; + + avg_time += + launch_and_time_kernel(stream_config, + kern_multiblock_welford_first_half, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + mean_var_count_grid_desc_m_g, + get_reduce_count_per_thread, + arg.numBlockTileIteration_, + arg.p_x_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_)); + + avg_time += + launch_and_time_kernel(stream_config, + kern_welford_second_half_batchnorm_forward_final, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + mean_var_count_grid_desc_m_k, + arg.scale_grid_desc_m_, + arg.bias_grid_desc_m_, + arg.mean_var_grid_desc_m_, + arg.blkGroupSize_, + arg.numBlockTileIteration_, + arg.epsilon_, + static_cast(arg.workspace_mean_), + static_cast(arg.workspace_variance_), + static_cast(arg.workspace_count_), + arg.p_x_, + arg.p_scale_, + arg.p_bias_, + arg.y_elementwise_op_, + arg.p_y_, + arg.updateMovingAverage_, + arg.averageFactor_, + arg.resultRunningMean_, + arg.resultRunningVariance_, + arg.saveMeanInvVariance_, + arg.resultSaveMean_, + arg.resultSaveInvVariance_); + } + else + { + using GetReduceCountPerThreadFunctor = + GetReduceCountPerThreadForBlockwiseWelford; + + GetReduceCountPerThreadFunctor get_reduce_count_per_thread( + arg.numBlockTileIteration_, arg.reduce_length_); + + using GridwiseBatchNormForwardWithBlockwiseWelford_ = + GridwiseBatchNormForwardWithBlockwiseWelford; + + const auto kern_batchnorm_fwd = kernel_batchnorm_forward_with_blockwise_welford< + GridwiseBatchNormForwardWithBlockwiseWelford_, + XDataType, + YDataType, + AccDataType, + ScaleDataType, + BiasDataType, + MeanVarDataType, + YElementwiseOp, + XYGridDesc_M_K, + ScaleBiasMeanVarGridDesc_M, + ScaleBiasMeanVarGridDesc_M, + GetReduceCountPerThreadFunctor>; + + avg_time += launch_and_time_kernel(stream_config, + kern_batchnorm_fwd, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + arg.scale_grid_desc_m_, + arg.bias_grid_desc_m_, + arg.mean_var_grid_desc_m_, + get_reduce_count_per_thread, + arg.numBlockTileIteration_, + arg.epsilon_, + arg.p_x_, + arg.p_scale_, + arg.p_bias_, + arg.y_elementwise_op_, + arg.p_y_, + arg.updateMovingAverage_, // true or false + arg.averageFactor_, + arg.resultRunningMean_, + arg.resultRunningVariance_, + arg.saveMeanInvVariance_, // true or false + arg.resultSaveMean_, + arg.resultSaveInvVariance_); + }; + + return (avg_time); + }; + + float Run(const BaseArgument* pArg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(pArg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* pArg) override + { + const Argument* pArg_ = dynamic_cast(pArg); + + if constexpr(XSrcYDstVectorDim == 0) + { + if(pArg_->xStrides_[NumInvariantDim - 1] != 1 || + pArg_->yStrides_[NumInvariantDim - 1] != 1) + return false; + + if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 || + pArg_->xyLengths_[NumInvariantDim - 1] % YDstVectorSize != 0) + return false; + } + else + { + if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->yStrides_[Rank - 1] != 1) + return false; + + if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 || + pArg_->xyLengths_[Rank - 1] % YDstVectorSize != 0) + return false; + }; + + if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1) + return false; + if(pArg_->bnBiasStrides_[NumInvariantDim - 1] 
!= 1 && BiasSrcVectorSize != 1) + return false; + + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0) + return false; + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasSrcVectorSize != 0) + return false; + + if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcDstVectorSize != 1) + return false; + + if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcDstVectorSize != 0) + return false; + + bool is_valid = true; + + static_for<0, NumInvariantDim, 1>{}([&](auto I) { + if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I]) + is_valid = false; + }); + + if(!is_valid) + return false; + + return true; + }; + + std::unique_ptr MakeArgumentPointer( + const std::array xyLengths, + const std::array xStrides, + const std::array yStrides, + const std::array reduceDims, + const std::array bnScaleBiasMeanVarLengths, + const std::array bnScaleStrides, + const std::array bnBiasStrides, + const std::array bnMeanVarStrides, + const void* p_x, + const void* p_scale, + const void* p_bias, + double epsilon, + const YElementwiseOp y_elementwise_op, + void* p_y, + void* resultSaveMean, + void* resultSaveInvVariance, + double averageFactor, + void* resultRunningMean, + void* resultRunningVariance) override + { + return std::make_unique(xyLengths, + xStrides, + yStrides, + reduceDims, + bnScaleBiasMeanVarLengths, + bnScaleStrides, + bnBiasStrides, + bnMeanVarStrides, + static_cast(p_x), + static_cast(p_scale), + static_cast(p_bias), + y_elementwise_op, + epsilon, + static_cast(p_y), + static_cast(resultSaveMean), + static_cast(resultSaveInvVariance), + averageFactor, + static_cast(resultRunningMean), + static_cast(resultRunningVariance)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceBatchNormFwdImpl<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "XSrcYDstVectorDim_" << XSrcYDstVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize << "_Y" << YDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp new file mode 100755 index 000000000000..8843e520a606 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp @@ -0,0 +1,625 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
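The cgemm header added next computes a complex GEMM on planar (separate real and imaginary) buffers by issuing four real GEMMs into two scratch buffers and combining them elementwise, in the order its Invoker uses: C_re = A_re*B_re - A_im*B_im and C_im = A_re*B_im + A_im*B_re. Below is a naive scalar reference of that decomposition with row-major `float` buffers; `RealGemm` and `ComplexGemm4x` are hypothetical helper names for illustration only and are not part of the patch.

```cpp
#include <cstddef>
#include <vector>

// Naive real GEMM: C(MxN) = A(MxK) * B(KxN), row-major.
static void RealGemm(const std::vector<float>& A, const std::vector<float>& B,
                     std::vector<float>& C, int M, int N, int K)
{
    C.assign(static_cast<std::size_t>(M) * N, 0.f);
    for(int m = 0; m < M; ++m)
        for(int k = 0; k < K; ++k)
            for(int n = 0; n < N; ++n)
                C[m * N + n] += A[m * K + k] * B[k * N + n];
}

// Complex GEMM on planar data via four real GEMMs and two elementwise passes,
// mirroring the kernel sequence in the Invoker: two GEMMs into aux/aux2,
// a subtract for the real part, two more GEMMs, an add for the imaginary part.
void ComplexGemm4x(const std::vector<float>& Ar, const std::vector<float>& Ai,
                   const std::vector<float>& Br, const std::vector<float>& Bi,
                   std::vector<float>& Cr, std::vector<float>& Ci,
                   int M, int N, int K)
{
    std::vector<float> aux, aux2;

    RealGemm(Ar, Br, aux,  M, N, K);  // aux  = A_re * B_re
    RealGemm(Ai, Bi, aux2, M, N, K);  // aux2 = A_im * B_im
    Cr.resize(static_cast<std::size_t>(M) * N);
    for(std::size_t i = 0; i < Cr.size(); ++i)
        Cr[i] = aux[i] - aux2[i];     // C_re = aux - aux2

    RealGemm(Ar, Bi, aux,  M, N, K);  // aux  = A_re * B_im
    RealGemm(Ai, Br, aux2, M, N, K);  // aux2 = A_im * B_re
    Ci.resize(static_cast<std::size_t>(M) * N);
    for(std::size_t i = 0; i < Ci.size(); ++i)
        Ci[i] = aux[i] + aux2[i];     // C_im = aux + aux2
}
```

This staging is also why the device Argument below takes a `p_workspace` pointer: the two intermediate products (`p_aux_grid` and `p_aux_2_grid`) live in a workspace holding two C-sized tensors.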
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_cgemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template < + typename ALayout, + typename BLayout, + typename CLayout, + typename ADataType, + typename BDataType, + typename CDataType, + typename GemmAccDataType, + typename CShuffleDataType, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + GemmSpecialization GemmSpec, + index_t NumGemmKPrefetchStage, + index_t BlockSize, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t AK1, + index_t BK1, + index_t MPerXDL, + index_t NPerXDL, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_AK0_M_AK1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_AK1, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_BK1, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + index_t CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopScheduler LoopSched = make_default_loop_scheduler(), + enable_if_t< + is_same_v && + is_same_v && + is_same_v, + bool> = false> +struct DeviceCGemm_4Gemm_Xdl_CShuffle + : public DeviceCGemm +{ + using DeviceOp = DeviceCGemm_4Gemm_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr index_t MPerThread = + MPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t NPerThread = + NPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + static constexpr auto AScalarPerVector = Number<4>{}; + static constexpr auto BScalarPerVector = Number<4>{}; + static constexpr auto CScalarPerVector = Number<4>{}; + + template + static auto PadDescriptor_M_N(Desc_M_N desc) + { + const auto M = desc.GetLength(I0); + const auto N = desc.GetLength(I1); + const auto pad_M = math::integer_divide_ceil(M, MPerThread) * MPerThread - M; + const auto pad_N = math::integer_divide_ceil(N, NPerThread) * NPerThread - N; + + const auto padded_desc = transform_tensor_descriptor( + desc, + make_tuple(make_right_pad_transform(M, pad_M), make_right_pad_transform(N, pad_N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, 
Sequence<1>{})); + + return padded_desc; + } + + static auto MakeDescriptor_M_N(const std::vector& lengths, + const std::vector& strides) + { + auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<2>{}); + auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<2>{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + return PadDescriptor_M_N(desc); + } + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + using CGridDesc_M_N = decltype(MakeDescriptor_M_N({1, 1}, {1, 1})); + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public GridwiseGemm::Problem + { + using Problem = typename GridwiseGemm::Problem; + + Argument(const ADataType* p_a_grid_real_, + const ADataType* p_a_grid_imag_, + const BDataType* p_b_grid_real_, + const BDataType* p_b_grid_imag_, + CDataType* p_c_grid_real_, + CDataType* p_c_grid_imag_, + CDataType* p_workspace, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid_real{p_a_grid_real_}, + p_a_grid_imag{p_a_grid_imag_}, + p_b_grid_real{p_b_grid_real_}, + p_b_grid_imag{p_b_grid_imag_}, + p_c_grid_real{p_c_grid_real_}, + p_c_grid_imag{p_c_grid_imag_}, + p_aux_grid{p_workspace} + { + if constexpr(is_same::value) + { + c_grid_desc_m_n = DeviceOp::MakeDescriptor_M_N({M_, N_}, {StrideC_, I1}); + } + else if constexpr(is_same::value) + { + c_grid_desc_m_n = DeviceOp::MakeDescriptor_M_N({M_, N_}, {I1, StrideC_}); + } + + p_aux_2_grid = p_workspace + GetCElementSpaceSize(M_, N_, StrideC_); + } + + // private: + const ADataType* p_a_grid_real; + const ADataType* p_a_grid_imag; + const BDataType* p_b_grid_real; + const BDataType* p_b_grid_imag; + CDataType* p_c_grid_real; + CDataType* p_c_grid_imag; + CDataType* p_aux_grid; + CDataType* p_aux_2_grid; + CGridDesc_M_N c_grid_desc_m_n; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); + + const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1; + + float ave_time = 0; + + using Add = ck::tensor_operation::element_wise::Add; + using Subtract = ck::tensor_operation::element_wise::Subtract; + + using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt; + + using GridwiseBinAdd = GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMap, + Add, + BlockSize, + MPerBlock, + NPerBlock, + MPerThread, + NPerThread, + Sequence<0, 1>, + Sequence, + Sequence, + I1, + I1>; + + using GridwiseBinSubtract = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMap, + Subtract, + BlockSize, + MPerBlock, + NPerBlock, + MPerThread, + NPerThread, + Sequence<0, 1>, + Sequence, + Sequence, + I1, + I1>; + + const index_t M = arg.c_grid_desc_m_n.GetLength(I0); + const index_t N = arg.c_grid_desc_m_n.GetLength(I1); + const auto block_2_tile_map = Block2TileMap(M, N); + + const auto add_kernel = kernel_elementwise, + Tuple, + Tuple, + Tuple, + Block2TileMap, + Add>; + + const auto subtract_kernel = + kernel_elementwise, + Tuple, + Tuple, + Tuple, + Block2TileMap, + Subtract>; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v1; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_real, + arg.p_b_grid_real, + arg.p_aux_grid, + arg); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_imag, + arg.p_b_grid_imag, + arg.p_aux_2_grid, + arg); + + // c_real = aux - aux_2 + ave_time += launch_and_time_kernel( + stream_config, + subtract_kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n), + make_tuple(arg.c_grid_desc_m_n), + make_tuple(const_cast(arg.p_aux_grid), + const_cast(arg.p_aux_2_grid)), + make_tuple(arg.p_c_grid_real), + block_2_tile_map, + Subtract{}); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_real, + arg.p_b_grid_imag, + arg.p_aux_grid, + arg); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_imag, + arg.p_b_grid_real, + arg.p_aux_2_grid, + arg); + + // c_imag = aux + aux_2 + ave_time += launch_and_time_kernel( + stream_config, + add_kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n), + make_tuple(arg.c_grid_desc_m_n), + make_tuple(const_cast(arg.p_aux_grid), + const_cast(arg.p_aux_2_grid)), + make_tuple(arg.p_c_grid_imag), + block_2_tile_map, + Add{}); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v1; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_real, + arg.p_b_grid_real, + arg.p_aux_grid, + arg); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_imag, + arg.p_b_grid_imag, + arg.p_aux_2_grid, + arg); + + // c_real = aux - aux_2 + ave_time += launch_and_time_kernel( + stream_config, + subtract_kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n), + make_tuple(arg.c_grid_desc_m_n), + make_tuple(const_cast(arg.p_aux_grid), + 
const_cast(arg.p_aux_2_grid)), + make_tuple(arg.p_c_grid_real), + block_2_tile_map, + Subtract{}); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_real, + arg.p_b_grid_imag, + arg.p_aux_grid, + arg); + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_imag, + arg.p_b_grid_real, + arg.p_aux_2_grid, + arg); + + // c_imag = aux + aux_2 + ave_time += launch_and_time_kernel( + stream_config, + add_kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + make_tuple(arg.c_grid_desc_m_n, arg.c_grid_desc_m_n), + make_tuple(arg.c_grid_desc_m_n), + make_tuple(const_cast(arg.p_aux_grid), + const_cast(arg.p_aux_2_grid)), + make_tuple(arg.p_c_grid_imag), + block_2_tile_map, + Add{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a_real, + const ADataType* p_a_imag, + const BDataType* p_b_real, + const BDataType* p_b_imag, + CDataType* p_c_real, + CDataType* p_c_imag, + CDataType* p_workspace, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a_real, + p_a_imag, + p_b_real, + p_b_imag, + p_c_real, + p_c_imag, + p_workspace, + M, + N, + K, + StrideA, + StrideB, + StrideC}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a_real, + const void* p_a_imag, + const void* p_b_real, + const void* p_b_imag, + void* p_c_real, + void* p_c_imag, + void* p_workspace, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + index_t /* KBatch */ = 1) override + { + return std::make_unique(static_cast(p_a_real), + static_cast(p_a_imag), + static_cast(p_b_real), + static_cast(p_b_imag), + static_cast(p_c_real), + static_cast(p_c_imag), + static_cast(p_workspace), + M, + N, + K, + StrideA, + StrideB, + StrideC); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceCGemm_4Gemm_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } + + static std::size_t GetCElementSpaceSize(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N( + M, GridwiseGemm::CalculateMPadded(M), N, GridwiseGemm::CalculateNPadded(N), StrideC); + + return c_grid_desc_m_n.GetElementSpaceSize(); + } + + std::size_t GetWorkspaceSize(index_t M, + 
index_t N, + [[maybe_unused]] index_t K, + [[maybe_unused]] index_t StrideA, + [[maybe_unused]] index_t StrideB, + index_t StrideC) const override + { + return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC); + } + + std::size_t GetWorkSpaceSize(const BaseArgument* base_arg) const override + { + const auto* parg = dynamic_cast(base_arg); + + if(!parg) + { + std::ostringstream err; + err << "Provided argument pointer is not of an Argument class!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + return GetWorkspaceSize( + parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp new file mode 100755 index 000000000000..9482812f7571 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp @@ -0,0 +1,647 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/library/utility/numeric.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" + +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Column to Image: +// input : gemm form [G, N * Do * Ho * Wo, Z * Y * X * C] +// output : input image [G, N, Di, Hi, Wi, C] +// input : gemm form [N * Do * Ho * Wo, G, Z * Y * X * C] +// output : input image [N, Di, Hi, Wi, G, C] +template = 1 && NDimSpatial <= 3, bool>::type = false> +struct DeviceColumnToImageImpl + : public DeviceConvTensorRearrange +{ + static constexpr bool is_NSpatialGC = + std::is_same_v || + std::is_same_v || + std::is_same_v; + static constexpr bool is_GNSpatialC = + std::is_same_v || + std::is_same_v || + std::is_same_v; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = NDimSpatial == 1 ? 
I0 : Number{}; + static constexpr auto XIdx = Number{}; + + static constexpr auto spatial_offset = Number<3>{}; + + using ConvToGemmFwdTransformer = + TransformConvFwdToGemm; + static constexpr auto matrix_padder = + MatrixPadder{ + MPerBlock, 0 /* NPerBlock*/, KPerBlock}; + + // Calculate number of independent filters for given conv params + static index_t GetNumberOfIndependentFilters(const index_t input_spatial_len, + const index_t left_pad, + const index_t right_pad, + const index_t filter_len, + const index_t filter_stride, + const index_t filter_dilation, + const index_t image_offset) + { + const index_t x_eff = (filter_len - 1) * filter_dilation + 1; + const index_t next_filter_padded = + math::integer_divide_ceil(x_eff, filter_stride) * filter_stride; + // If filter_stride >= x_eff then each filter is independent + const index_t independent_filter_stride = + filter_stride >= x_eff ? filter_stride : next_filter_padded; + const index_t w_eff = input_spatial_len - image_offset + left_pad + right_pad - x_eff; + // There are no independent filters + if(w_eff < 0) + return 0; + const index_t independent_kernels_num = w_eff / independent_filter_stride + 1; + return independent_kernels_num; + } + + // Make column form descriptor + static auto + MakeInputDescriptor_M_K(const ck::index_t N, + const ck::index_t C, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& conv_filter_strides, + const std::array& gemm_g_m_k_strides, + const std::array& independent_filters, + const std::array& effs) + { + const index_t DoHoWo = ck::accumulate_n( + output_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); + const index_t CZYX = + C * ck::accumulate_n( + filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); + + const index_t NStride = DoHoWo * gemm_g_m_k_strides[I1] * gemm_g_m_k_strides[I2]; + // Calculate the appropriate stride for each set of independent filters + // in each dimension + const index_t WStride = math::integer_divide_ceil(effs[XIdx], conv_filter_strides[XIdx]) * + gemm_g_m_k_strides[I1]; + const index_t HStride = math::integer_divide_ceil(effs[YIdx], conv_filter_strides[YIdx]) * + output_spatial_lengths[XIdx] * gemm_g_m_k_strides[I1]; + const index_t DStride = math::integer_divide_ceil(effs[ZIdx], conv_filter_strides[ZIdx]) * + output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx] * + gemm_g_m_k_strides[I1]; + // Create descriptor for independent filters in each dimension and + // then merge them into column form + if constexpr(NDimSpatial == 1) + { + const auto desc_gemm_form = + make_naive_tensor_descriptor(make_tuple(N, independent_filters[XIdx], CZYX), + make_tuple(NStride, WStride, gemm_g_m_k_strides[I2])); + const auto desc_gemm_form_merged_filters = transform_tensor_descriptor( + desc_gemm_form, + make_tuple(make_merge_transform(make_tuple(N, independent_filters[XIdx])), + make_pass_through_transform(CZYX)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters); + return desc_m_k; + } + else if constexpr(NDimSpatial == 2) + { + const auto desc_gemm_form = make_naive_tensor_descriptor( + make_tuple(N, independent_filters[YIdx], independent_filters[XIdx], CZYX), + make_tuple(NStride, HStride, WStride, gemm_g_m_k_strides[I2])); + const auto desc_gemm_form_merged_filters = transform_tensor_descriptor( + desc_gemm_form, + make_tuple(make_merge_transform( + make_tuple(N, 
independent_filters[YIdx], independent_filters[XIdx])), + make_pass_through_transform(CZYX)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters); + return desc_m_k; + } + else if constexpr(NDimSpatial == 3) + { + const auto desc_gemm_form = make_naive_tensor_descriptor( + make_tuple(N, + independent_filters[ZIdx], + independent_filters[YIdx], + independent_filters[XIdx], + CZYX), + make_tuple(NStride, DStride, HStride, WStride, gemm_g_m_k_strides[I2])); + const auto desc_gemm_form_merged_filters = transform_tensor_descriptor( + desc_gemm_form, + make_tuple(make_merge_transform(make_tuple(N, + independent_filters[ZIdx], + independent_filters[YIdx], + independent_filters[XIdx])), + make_pass_through_transform(CZYX)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters); + return desc_m_k; + } + } + + // Use MakeADescriptor_M_K from grouped convolution forward + static auto + MakeOutDescriptor_M_K(const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const std::array& image_offsets, + const std::array& independent_filters, + const std::array& effs) + { + std::array a_g_n_c_wis_lengths{1}; + std::array b_g_k_c_xs_lengths{1}; + std::array c_g_n_k_wos_lengths{1}; + + auto copy = [](const auto& x, auto& y, index_t dst_offset) { + std::copy(x.begin(), x.end(), y.begin() + dst_offset); + }; + + copy(input_spatial_lengths, a_g_n_c_wis_lengths, spatial_offset); + copy(filter_spatial_lengths, b_g_k_c_xs_lengths, spatial_offset); + // Calculate descriptor only for independent filters + copy(independent_filters, c_g_n_k_wos_lengths, spatial_offset); + + // fill only significant values (C and N) + a_g_n_c_wis_lengths[I1] = N; + a_g_n_c_wis_lengths[I2] = C; + b_g_k_c_xs_lengths[I2] = C; + c_g_n_k_wos_lengths[I1] = N; + + // Modify pads to apply offsets + std::array input_left_pads_with_offset; + for(index_t i = 0; i < NDimSpatial; i++) + { + input_left_pads_with_offset[i] = math::max(0, input_left_pads[i] - image_offsets[i]); + } + // Modify input spatial lengths to apply offsets + for(index_t i = 0; i < NDimSpatial; i++) + { + a_g_n_c_wis_lengths[i + spatial_offset] -= + math::max(0, image_offsets[i] - input_left_pads[i]); + } + + // Strides to next independent filters + std::array independent_filter_strides; + for(index_t i = 0; i < NDimSpatial; i++) + { + index_t independent_filter_stride = + math::integer_divide_ceil(effs[i], conv_filter_strides[i]) * conv_filter_strides[i]; + // If conv stride is greater than whole filter size, use conv stride + independent_filter_strides[i] = conv_filter_strides[i] >= effs[i] + ? 
conv_filter_strides[i] + : independent_filter_stride; + } + + ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths, + image_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + {}, // not needed for A Descriptor + c_g_n_k_wos_lengths, + {}, // not needed for A Descriptor + // conv_filter_strides, + independent_filter_strides, + conv_filter_dilations, + input_left_pads_with_offset, + input_right_pads}; + + // Calculate image form descriptor for the modified convolution problem + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + return in_gemmm_gemmk_desc; + } + + using InputGridDesc = + remove_cvref_t; + using OutputGridDesc = remove_cvref_t; + + using Block2ETileMap = remove_cvref_t< + decltype(BlockToCTileMap_M00_N0_M01Adapt( + InputGridDesc{}))>; + + using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange>; + + struct Argument : public BaseArgument + { + Argument(const void* p_in, // input image + void* p_out, // output image + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + : G_(G), + C_(C), + X_(filter_spatial_lengths[NDimSpatial - I1]), + p_in_{static_cast(p_in)}, + p_out_{static_cast(p_out)}, + image_g_n_c_wis_strides_{image_g_n_c_wis_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + compute_ptr_offset_of_batch_.BatchStrideA_ = gemm_g_m_k_strides[I0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = image_g_n_c_wis_strides[I0]; + + const index_t x_eff = + (filter_spatial_lengths[XIdx] - 1) * conv_filter_dilations[XIdx] + 1; + const index_t y_eff = + NDimSpatial < 2 + ? I1 + : (filter_spatial_lengths[YIdx] - 1) * conv_filter_dilations[YIdx] + 1; + const index_t z_eff = + NDimSpatial < 3 + ? 
I1 + : (filter_spatial_lengths[ZIdx] - 1) * conv_filter_dilations[ZIdx] + 1; + + // Iterate over sets of independent filters + for(int z_img_offset = 0; z_img_offset < z_eff; + z_img_offset += conv_filter_strides[ZIdx]) + { + for(int y_img_offset = 0; y_img_offset < y_eff; + y_img_offset += conv_filter_strides[YIdx]) + { + for(int x_img_offset = 0; x_img_offset < x_eff; + x_img_offset += conv_filter_strides[XIdx]) + { + + std::array image_offsets; + std::array effs; + // Calculate the starting offset for a given set of + // independent filters + if constexpr(NDimSpatial == 1) + { + image_offsets = {x_img_offset}; + effs = {x_eff}; + } + if constexpr(NDimSpatial == 2) + { + image_offsets = {y_img_offset, x_img_offset}; + effs = {y_eff, x_eff}; + } + else if constexpr(NDimSpatial == 3) + { + image_offsets = {z_img_offset, y_img_offset, x_img_offset}; + effs = {z_eff, y_eff, x_eff}; + } + + std::array independent_filters; + for(index_t i = 0; i < NDimSpatial; i++) + { + independent_filters[i] = + GetNumberOfIndependentFilters(input_spatial_lengths[i], + input_left_pads[i], + input_right_pads[i], + filter_spatial_lengths[i], + conv_filter_strides[i], + conv_filter_dilations[i], + image_offsets[i]); + } + const index_t independent_filters_acum = ck::accumulate_n( + independent_filters.begin(), NDimSpatial, 1, std::multiplies<>()); + if(independent_filters_acum <= 0) + continue; + + const auto in_grid_desc_m_k = + MakeInputDescriptor_M_K(N, + C, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + gemm_g_m_k_strides, + independent_filters, + effs); + const auto out_grid_desc_m_k = + MakeOutDescriptor_M_K(N, + C, + input_spatial_lengths, + filter_spatial_lengths, + image_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + image_offsets, + independent_filters, + effs); + in_grid_desc_m_k_container_.push_back(in_grid_desc_m_k); + out_grid_desc_m_k_container_.push_back(out_grid_desc_m_k); + + const index_t x_idx = x_img_offset / conv_filter_strides[XIdx]; + const index_t y_idx = y_img_offset / conv_filter_strides[YIdx]; + const index_t z_idx = z_img_offset / conv_filter_strides[ZIdx]; + + const index_t x_offset_with_pad = + math::max(0, x_img_offset - input_left_pads[XIdx]); + const index_t y_offset_with_pad = + math::max(0, y_img_offset - input_left_pads[YIdx]); + const index_t z_offset_with_pad = + math::max(0, z_img_offset - input_left_pads[ZIdx]); + + // Memory offsets to next set of independent filters, + // move to independent filters in each dimension + const index_t in_offset = + (x_idx + y_idx * output_spatial_lengths[XIdx] + + z_idx * output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx]) * + gemm_g_m_k_strides[I1]; + // Move to independent filters in appropriate dimensions + const index_t out_offset = + x_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + XIdx] + + y_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + YIdx] + + z_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + ZIdx]; + + const InputDataType* p_in_with_offset = + static_cast(p_in) + in_offset; + OutputDataType* p_out_with_offset = + static_cast(p_out) + out_offset; + p_in_container_.push_back(p_in_with_offset); + p_out_container_.push_back(p_out_with_offset); + } + } + } + } + + void Print() const + { + for(std::size_t i = 0; i < in_grid_desc_m_k_container_.size(); i++) + { + std::cout << in_grid_desc_m_k_container_[i] << std::endl; + std::cout << out_grid_desc_m_k_container_[i] << std::endl; + } + } + + const 
ck::index_t G_; + const ck::index_t C_; + const ck::index_t X_; + + const InputDataType* p_in_; + OutputDataType* p_out_; + + const std::array& image_g_n_c_wis_strides_; + const std::array& conv_filter_strides_; + const std::array& conv_filter_dilations_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + + std::vector in_grid_desc_m_k_container_; + std::vector out_grid_desc_m_k_container_; + + std::vector p_in_container_; + std::vector p_out_container_; + + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + float elapsed_time = 0.f; + const auto kernel = kernel_tensor_rearrange, + GridwiseTensorRearrangeKernel>; + + // Execute each set of independent filters + for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++) + { + const auto block_2_tile_map = + BlockToCTileMap_M00_N0_M01Adapt( + arg.out_grid_desc_m_k_container_[i]); + const index_t grid_size = + block_2_tile_map.CalculateGridSize(arg.in_grid_desc_m_k_container_[i]) * arg.G_; + elapsed_time += launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.in_grid_desc_m_k_container_[i], + arg.p_in_container_[i], + arg.out_grid_desc_m_k_container_[i], + arg.p_out_container_[i], + arg.G_, + block_2_tile_map, + arg.compute_ptr_offset_of_batch_); + } + return elapsed_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const Argument& arg) + { + using namespace tensor_layout::convolution; + if constexpr(!(is_NSpatialGC || is_GNSpatialC)) + { + return false; + } + + const auto w_pad_left = arg.input_left_pads_[NDimSpatial - I1]; + const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1]; + const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1]; + const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1]; + bool is_w_packed = arg.image_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_; + bool is_c_packed = arg.image_g_n_c_wis_strides_[I2] == 1; + + // check vector acces with c not packed + if(!is_c_packed && ScalarPerVector != 1) + return false; + // check vector access of filter window row (only C if C is not packed) + if(!is_w_packed && arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of filter window row (X * C) + if(arg.X_ * arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of pads (w_pad_left/w_pad_right * C) + if(w_pad_left * arg.C_ % ScalarPerVector != 0 || + w_pad_right * arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of with stride and pad + if((w_pad_left != 0 || w_pad_right != 0) && stride_x > 1 && arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of with dilation + if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0) + return false; + + bool valid = true; + for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++) + { + valid &= GridwiseTensorRearrangeKernel::CheckValidity( + arg.in_grid_desc_m_k_container_[i], arg.out_grid_desc_m_k_container_[i]); + } + return valid; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_in, // input image + 
void* p_out, // output image + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + return Argument{static_cast(p_in), + static_cast(p_out), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in, // input image + void* p_out, // output image + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) override + { + return std::make_unique(static_cast(p_in), + static_cast(p_out), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceColumnToImage" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << KPerBlock << ", " + << ScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp new file mode 100755 index 000000000000..df5922a04f01 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,847 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
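+//
+// Note (descriptive comment, not in the upstream header): this file provides a
+// device-level tensor contraction in which A, B and the Ds can each be a tuple
+// of tensors (AsDataType / BsDataType / DsDataType). The element-wise operators
+// fuse those tuples around the XDL GEMM mainloop, and the multi-dimensional
+// M/N/K index sets are merged into a single 2-D GEMM problem by the descriptor
+// helpers further down in this file.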
+ +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_contraction_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, + const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_as_grid; + ignore = p_bs_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = as_grid_desc_ak0_m_ak1; + ignore = bs_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct DeviceContractionMultipleABD_Xdl_CShuffle + : public DeviceContractionMultipleABD +{ + using DeviceOp = DeviceContractionMultipleABD_Xdl_CShuffle; + + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ComputeDataType = EDataType; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle< + AsDataType, + BsDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer>; + + static constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + static auto MakeAGridDescriptor_M_K(const std::vector& a_ms_ks_lengths_, + const std::vector& a_ms_ks_strides_) + { + assert(a_ms_ks_lengths_.size() == NumDimM + NumDimK && + a_ms_ks_strides_.size() == NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_, Number{}); + const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); + + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] 
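+        // Illustrative example (hypothetical sizes): with NumDimM = 2 and NumDimK = 2,
+        // A[M0=8, M1=16, K0=4, K1=32] is merged here into A[MRaw=128, KRaw=128] and then
+        // padded (depending on GemmSpec) to multiples of MPerBlock/KPerBlock by
+        // matrix_padder below.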
+ const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + __host__ __device__ static auto + MakeAsGridDescriptor_M_K(const std::array, NumATensor>& as_ms_ks_lengths, + const std::array, NumATensor>& as_ms_ks_strides) + { + return generate_tuple( + [&](auto i) { + return MakeAGridDescriptor_M_K(as_ms_ks_lengths[i], as_ms_ks_strides[i]); + }, + Number{}); + } + + // Assume: B[N0, N1, N2, ..., K0, K1, K2, ...] + static auto MakeBGridDescriptor_N_K(const std::vector& b_ns_ks_lengths_, + const std::vector& b_ns_ks_strides_) + { + assert(b_ns_ks_lengths_.size() == NumDimN + NumDimK && + b_ns_ks_strides_.size() == NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_, Number{}); + const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); + + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] + const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + __host__ __device__ static auto + MakeBsGridDescriptor_N_K(const std::array, NumBTensor>& bs_ns_ks_lengths, + const std::array, NumBTensor>& bs_ns_ks_strides) + { + return generate_tuple( + [&](auto i) { + return MakeBGridDescriptor_N_K(bs_ns_ks_lengths[i], bs_ns_ks_strides[i]); + }, + Number{}); + } + + // assume E[M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_M_N(const std::vector& e_ms_ns_lengths_, + const std::vector& e_ms_ns_strides_) + { + assert(e_ms_ns_lengths_.size() == NumDimM + NumDimN && + e_ms_ns_strides_.size() == NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_, Number{}); + const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); + + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] 
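+        // (the same helper is reused for every D tensor via MakeDsGridDescriptor_M_N,
+        // since the Ds share E's [M0, M1, ..., N0, N1, ...] ordering)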
+ const auto e_grid_desc_ms_ns = + make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); + + // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] + const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + static auto + MakeDsGridDescriptor_M_N(const std::array, NumDTensor>& ds_ms_ns_lengths, + const std::array, NumDTensor>& ds_ms_ns_strides) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); + }, + Number{}); + } + + // desc for problem definition + using AsGridDesc_M_K = remove_cvref_t; + using BsGridDesc_N_K = remove_cvref_t; + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = remove_cvref_t; + + // desc for blockwise copy + using AsGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BsGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::array p_as_grid, + std::array p_bs_grid, + std::array p_ds_grid, + void* p_e_grid, + const std::array, NumATensor>& a_ms_ks_lengths, + const std::array, NumATensor>& a_ms_ks_strides, + const std::array, NumBTensor>& b_ns_ks_lengths, + const std::array, NumBTensor>& b_ns_ks_strides, + const std::array, NumDTensor>& d_ms_ns_lengths, + const std::array, NumDTensor>& d_ms_ns_strides, + const std::vector& e_ms_ns_length, + const std::vector& e_ms_ns_stride, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_as_grid_{}, + p_bs_grid_{}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + as_grid_desc_m_k_{}, + bs_grid_desc_n_k_{}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{MakeEGridDescriptor_M_N(e_ms_ns_length, e_ms_ns_stride)}, + as_grid_desc_ak0_m_ak1_{}, + bs_grid_desc_bk0_n_bk1_{}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + // populate pointer, desc for As + static_for<0, NumATensor, 1>{}([&](auto i) { + // using ALayout = remove_cvref_t>; + using ADataType = remove_cvref_t>; + + // A pointer + p_as_grid_(i) = static_cast(p_as_grid[i]); + + // A desc + as_grid_desc_m_k_(i) = + MakeAGridDescriptor_M_K(a_ms_ks_lengths[i], a_ms_ks_strides[i]); + }); + + // populate pointer, desc for Bs + static_for<0, NumBTensor, 1>{}([&](auto i) { + // using BLayout = remove_cvref_t>; + using BDataType = remove_cvref_t>; + + // B pointer + p_bs_grid_(i) = static_cast(p_bs_grid[i]); + + // B desc + bs_grid_desc_n_k_(i) = + MakeBGridDescriptor_N_K(b_ns_ks_lengths[i], b_ns_ks_strides[i]); + }); + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + // using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + 
ds_grid_desc_m_n_(i) = + MakeEGridDescriptor_M_N(d_ms_ns_lengths[i], d_ms_ns_strides[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(as_grid_desc_m_k_, + bs_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + as_grid_desc_ak0_m_ak1_ = + GridwiseGemm::MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k_); + + bs_grid_desc_bk0_n_bk1_ = + GridwiseGemm::MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + } + + // for sanity check of vector memory access + for(index_t i = 0; i < NumATensor; ++i) + { + tie(as_continous_dim_[i], as_max_read_elems_[i]) = + CalculateMaxRead(a_ms_ks_lengths[i], a_ms_ks_strides[i]); + } + + for(index_t i = 0; i < NumBTensor; ++i) + { + tie(bs_continous_dim_[i], bs_max_read_elems_[i]) = + CalculateMaxRead(b_ns_ks_lengths[i], b_ns_ks_strides[i]); + } + + for(index_t i = 0; i < NumDTensor; ++i) + { + tie(ds_continous_dim_[i], ds_max_read_elems_[i]) = + CalculateMaxRead(d_ms_ns_lengths[i], d_ms_ns_strides[i]); + } + + tie(e_continous_dim_, e_max_write_elems_) = + CalculateMaxRead(e_ms_ns_length, e_ms_ns_stride); + } + + // pointers + typename GridwiseGemm::AsGridPointer p_as_grid_; + typename GridwiseGemm::BsGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + AsGridDesc_M_K as_grid_desc_m_k_; + BsGridDesc_N_K bs_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1_; + BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // Describe whether the last part of a given dimension of A/B/D/E is continues dim. + std::array as_continous_dim_; + std::array bs_continous_dim_; + std::array ds_continous_dim_; + index_t e_continous_dim_; + + std::array as_max_read_elems_; + std::array bs_max_read_elems_; + std::array ds_max_read_elems_; + index_t e_max_write_elems_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, + arg.bs_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_contraction_multiple_abd_xdl_cshuffle< + GridwiseGemm, + typename GridwiseGemm::AsGridPointer, + typename GridwiseGemm::BsGridPointer, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AsGridDesc_AK0_M_AK1, + DeviceOp::BsGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_as_grid_, + arg.p_bs_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.as_grid_desc_ak0_m_ak1_, + arg.bs_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + const auto K = arg.as_grid_desc_m_k_[I0].GetLength(I1); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + // check vector load/store + { + bool valid_as_access = true; + static_for<0, NumATensor, 1>{}([&](auto i) { + const bool valid_a_vector_size = + arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0; + const bool valid_a_access_dim_m = + ABlockTransferSrcVectorDim == 1 && arg.as_continous_dim_[i] == 0; + const bool valid_a_access_dim_k = + ABlockTransferSrcVectorDim == 2 && arg.as_continous_dim_[i] == 1; + const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k; + if(!((valid_a_vector_size && valid_a_access_dim) || + ABlockTransferSrcScalarPerVector == 1)) + { + valid_as_access = false; + } + }); + if(!valid_as_access) + { + return false; + } + + bool valid_bs_access = true; + static_for<0, NumBTensor, 1>{}([&](auto i) { + const bool valid_b_vector_size = + arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0; + const bool valid_b_access_dim_n = + BBlockTransferSrcVectorDim == 1 && arg.bs_continous_dim_[i] == 0; + const bool valid_b_access_dim_k = + BBlockTransferSrcVectorDim == 2 && arg.bs_continous_dim_[i] == 1; + const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k; + if(!((valid_b_vector_size && valid_b_access_dim) || + BBlockTransferSrcScalarPerVector == 1)) + { + valid_bs_access = false; + } + }); + if(!valid_bs_access) + { + return false; + } + + bool valid_ds_access = true; + static_for<0, NumDTensor, 1>{}([&](auto i) { + const bool valid_d_vector_size = + arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; + // Vector read of Ds is always on N dimension. 
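+                // i.e. a D tensor can only be read CDEBlockTransferScalarPerVector_NPerBlock
+                // elements at a time when its contiguous (fastest-varying) dimension is the
+                // merged N dimension; otherwise the argument is rejected unless scalar
+                // access (scalar-per-vector == 1) was requested.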
+ const bool valid_d_access_dim = arg.ds_continous_dim_[i] == 1; + if(!((valid_d_vector_size && valid_d_access_dim) || + CDEBlockTransferScalarPerVector_NPerBlock == 1)) + { + valid_ds_access = false; + } + }); + if(!valid_ds_access) + { + return false; + } + + const bool valid_e_vector_size = + arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; + // Vector write of E is always on N dimension. + const bool valid_e_access_dim = arg.e_continous_dim_ == 1; + if(!((valid_e_vector_size && valid_e_access_dim) || + CDEBlockTransferScalarPerVector_NPerBlock == 1)) + { + return false; + } + } + + return GridwiseGemm::CheckValidity(arg.as_grid_desc_m_k_, + arg.bs_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + const std::array, NumATensor>& a_ms_ks_lengths, + const std::array, NumATensor>& a_ms_ks_strides, + const std::array, NumBTensor>& b_ns_ks_lengths, + const std::array, NumBTensor>& b_ns_ks_strides, + const std::array, NumDTensor>& d_ms_ns_lengths, + const std::array, NumDTensor>& d_ms_ns_strides, + const std::vector& e_ms_ns_length, + const std::vector& e_ms_ns_stride, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + d_ms_ns_lengths, + d_ms_ns_strides, + e_ms_ns_length, + e_ms_ns_stride, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + const std::array, NumATensor>& as_ms_ks_lengths, + const std::array, NumATensor>& as_ms_ks_strides, + const std::array, NumBTensor>& bs_ns_ks_lengths, + const std::array, NumBTensor>& bs_ns_ks_strides, + const std::array, NumDTensor>& ds_ms_ns_lengths, + const std::array, NumDTensor>& ds_ms_ns_strides, + const std::vector& e_ms_ns_length, + const std::vector& e_ms_ns_stride, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_as, + p_bs, + p_ds, + p_e, + as_ms_ks_lengths, + as_ms_ks_strides, + bs_ns_ks_lengths, + bs_ns_ks_strides, + ds_ms_ns_lengths, + ds_ms_ns_strides, + e_ms_ns_length, + e_ms_ns_stride, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceContractionMultipleABD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << 
CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp new file mode 100755 index 000000000000..77974f84aeeb --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,780 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_contraction_multiple_d_xdl_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... 
+// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[M0, M1, M2, ..., K0, K1, K2, ...] +// B[N0, N1, N2, ..., K0, K1, K2, ...] +// D[M0, M1, M2, ..., N0, N1, N2, ...] +// E[M0, M1, M2, ..., N0, N1, N2, ...] +template +struct DeviceContractionMultipleD_Xdl_CShuffle + : public DeviceContractionMultipleD +{ + using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + // Assume: A[M0, M1, M2, ..., K0, K1, K2, ...] + static auto MakeAGridDescriptor_M_K(const std::vector& a_ms_ks_lengths_vec, + const std::vector& a_ms_ks_strides_vec) + { + assert(a_ms_ks_lengths_vec.size() == NumDimM + NumDimK && + a_ms_ks_strides_vec.size() == NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_vec, Number{}); + const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); + + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] + const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + // Assume: B[N0, N1, N2, ..., K0, K1, K2, ...] + static auto MakeBGridDescriptor_N_K(const std::vector& b_ns_ks_lengths_vec, + const std::vector& b_ns_ks_strides_vec) + { + assert(b_ns_ks_lengths_vec.size() == NumDimN + NumDimK && + b_ns_ks_strides_vec.size() == NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_vec, Number{}); + const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_vec, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); + + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] + const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] 
+ const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + // assume E[M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_M_N(const std::vector& e_ms_ns_lengths_vec, + const std::vector& e_ms_ns_strides_vec) + { + assert(e_ms_ns_lengths_vec.size() == NumDimM + NumDimN && + e_ms_ns_strides_vec.size() == NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_vec, Number{}); + const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_vec, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); + + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_ms_ns = + make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); + + // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] + const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths_vec[i], + ds_ms_ns_strides_vec[i]); + }, + Number{}); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {})); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {})); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; 
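+
+    // Illustrative host-side flow for this device op (a sketch only; the pointers,
+    // length/stride containers and element-wise ops named below are hypothetical
+    // placeholders, not definitions from this header):
+    //
+    //   auto op  = DeviceContractionMultipleD_Xdl_CShuffle</* template config */>{};
+    //   auto arg = op.MakeArgument(p_a, p_b, {p_d0}, p_e,
+    //                              a_lengths, a_strides,
+    //                              b_lengths, b_strides,
+    //                              {d0_lengths}, {d0_strides},
+    //                              e_lengths, e_strides,
+    //                              a_op, b_op, cde_op);
+    //   if(op.IsSupportedArgument(arg))
+    //   {
+    //       op.MakeInvoker().Run(arg);
+    //   }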
+ + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + const std::vector& a_ms_ks_lengths, + const std::vector& a_ms_ks_strides, + const std::vector& b_ns_ks_lengths, + const std::vector& b_ns_ks_strides, + const std::array, NumDTensor>& ds_ms_ns_lengths, + const std::array, NumDTensor>& ds_ms_ns_strides, + const std::vector& e_ms_ns_lengths, + const std::vector& e_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ks_lengths, a_ms_ks_strides)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_ns_ks_lengths, b_ns_ks_strides)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + + // for sanity check of vector memory access + tie(a_continous_dim_, a_max_read_elems_) = + CalculateMaxRead(a_ms_ks_lengths, a_ms_ks_strides); + + tie(b_continous_dim_, b_max_read_elems_) = + CalculateMaxRead(b_ns_ks_lengths, b_ns_ks_strides); + + for(index_t i = 0; i < NumDTensor; ++i) + { + tie(ds_continous_dim_[i], ds_max_read_elems_[i]) = + CalculateMaxRead(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]); + } + + tie(e_continous_dim_, e_max_write_elems_) = + CalculateMaxRead(e_ms_ns_lengths, e_ms_ns_strides); + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << 
e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // Describe whether the last part of a given dimension of A/B/D/E is continues dim. + index_t a_continous_dim_; + index_t b_continous_dim_; + std::array ds_continous_dim_; + index_t e_continous_dim_; + + index_t a_max_read_elems_; + index_t b_max_read_elems_; + std::array ds_max_read_elems_; + index_t e_max_write_elems_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!ck::is_lds_direct_load_supported() && std::is_same::value) + { + return false; + } + + 
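+        // The gridwise GEMM must accept the problem shape (grid descriptors and the
+        // block-to-E-tile map) before the per-tensor vectorized memory-access checks
+        // below are evaluated.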
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + return false; + } + + // check vector access + static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) && + (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), + "wrong!"); + + const bool valid_a_vector_size = + arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0; + const bool valid_a_access_dim_m = + ABlockTransferSrcVectorDim == 1 && arg.a_continous_dim_ == 0; + const bool valid_a_access_dim_k = + ABlockTransferSrcVectorDim == 2 && arg.a_continous_dim_ == 1; + const bool valid_a_access_dim = + valid_a_access_dim_m || valid_a_access_dim_k || ABlockTransferSrcScalarPerVector == 1; + if(!(valid_a_vector_size && valid_a_access_dim)) + { + return false; + } + + const bool valid_b_vector_size = + arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0; + const bool valid_b_access_dim_n = + BBlockTransferSrcVectorDim == 1 && arg.b_continous_dim_ == 0; + const bool valid_b_access_dim_k = + BBlockTransferSrcVectorDim == 2 && arg.b_continous_dim_ == 1; + const bool valid_b_access_dim = + valid_b_access_dim_n || valid_b_access_dim_k || BBlockTransferSrcScalarPerVector == 1; + if(!(valid_b_vector_size && valid_b_access_dim)) + { + return false; + } + + bool valid_ds_access = true; + static_for<0, NumDTensor, 1>{}([&](auto i) { + const bool valid_d_vector_size = + arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0; + // Vector read of Ds is always on N dimension. + const bool valid_d_access_dim = + arg.ds_continous_dim_[i] == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1; + if(!(valid_d_vector_size && valid_d_access_dim)) + { + valid_ds_access = false; + } + }); + if(!valid_ds_access) + { + return false; + } + + const bool valid_e_vector_size = + arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0; + // Vector write of E is always on N dimension. 
+ const bool valid_e_access_dim = + arg.e_continous_dim_ == 1 || CDEBlockTransferScalarPerVector_NPerBlock == 1; + if(!(valid_e_vector_size && valid_e_access_dim)) + { + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_ms_ks_lengths, + const std::vector& a_ms_ks_strides, + const std::vector& b_ns_ks_lengths, + const std::vector& b_ns_ks_strides, + const std::array, NumDTensor>& ds_ms_ns_lengths, + const std::array, NumDTensor>& ds_ms_ns_strides, + const std::vector& e_ms_ns_lengths, + const std::vector& e_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + ds_ms_ns_lengths, + ds_ms_ns_strides, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_ms_ks_lengths, + const std::vector& a_ms_ks_strides, + const std::vector& b_ns_ks_lengths, + const std::vector& b_ns_ks_strides, + const std::array, NumDTensor>& ds_ms_ns_lengths, + const std::array, NumDTensor>& ds_ms_ns_strides, + const std::vector& e_ms_ns_lengths, + const std::vector& e_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_ms_ks_lengths, + a_ms_ks_strides, + b_ns_ks_lengths, + b_ns_ks_strides, + ds_ms_ns_lengths, + ds_ms_ns_strides, + e_ms_ns_lengths, + e_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceContractionMultipleD_Xdl_CShuffle" + << "<" + << NumDimM << ", " + << NumDimN << ", " + << NumDimK << ", " + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << ABlockTransferSrcVectorDim << ", " + << BBlockTransferSrcVectorDim + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp new file mode 100755 index 000000000000..1b0db73fdd09 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/** + * Calculates the maximum number of subsequent elements of the fast changing dimension + * that are consecutive in memory. 
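+ *
+ * The result is a tuple: the first element indicates which dimension group is treated
+ * as contiguous (0 for the leading `NumDim1` dimensions, 1 for the trailing `NumDim2`
+ * dimensions), and the second element is the maximum number of consecutive elements
+ * that can be read or written as a single vector.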
+ * + * Example: + * NumDimM = 2, NumDimK = 3 + * A shape = [ 2, 3, 4, 5, 6] + * A strides = [360, 120, 30, 6, 1] + * | M | | K | + * It follows from strides that K is FCD and all the subsequent elements of K are consecutive + * in memory. + * But if strides were [360, 120, 6, 24, 1], then only 6 subsequent elements of K would be + * consecutive in memory. + * + * Assumes that the dimensions are split into two groups of `NumDim1` and `NumDim2` dimensions. + */ +template +auto CalculateMaxRead(const std::vector& lengths, const std::vector& strides) +{ + if(lengths.size() != NumDim1 + NumDim2) + { + std::ostringstream err; + err << "Incorrect number of lengths in " + << "device_contraction_utils.hpp" + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + if(strides.size() != NumDim1 + NumDim2) + { + std::ostringstream err; + err << "Incorrect number of strides in " + << "device_contraction_utils.hpp" + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + // Determine the beginning and end idx of the group representing the FCD. + index_t begin_idx, end_idx, continous_dim, consecutive_stride = 1; + if(strides[NumDim1 - 1] == 1 && strides[NumDim1 + NumDim2 - 1] == 1) + { + // MZ or KZ are ones + bool dims1_are_ones = true; + for(index_t dim_idx = 0; dim_idx < NumDim1; dim_idx++) + { + if(lengths[dim_idx] != 1) + { + dims1_are_ones = false; + } + } + + if(dims1_are_ones) + { + begin_idx = NumDim1; + end_idx = NumDim1 + NumDim2 - 1; + continous_dim = 1; + } + else + { + begin_idx = 0; + end_idx = NumDim1 - 1; + continous_dim = 0; + } + } + else if(strides[NumDim1 - 1] == 1) + { + begin_idx = 0; + end_idx = NumDim1 - 1; + continous_dim = 0; + } + else if(strides[NumDim1 + NumDim2 - 1] == 1) + { + begin_idx = NumDim1; + end_idx = NumDim1 + NumDim2 - 1; + continous_dim = 1; + } + else + { + // The dimension consecutive in memory is not the last dimension of any group, so only + // one element can be read/written at once. + consecutive_stride = 1; + continous_dim = 0; + return make_tuple(continous_dim, consecutive_stride); + } + + for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx) + { + if(strides[dim_idx] == consecutive_stride) + { + consecutive_stride *= lengths[dim_idx]; + } + else + { + break; + } + } + const index_t max_subsequent_elems = consecutive_stride; + return make_tuple(continous_dim, max_subsequent_elems); +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp new file mode 100755 index 000000000000..23440e24f682 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,808 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvBwdWeight<2, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK, + InDataType, + WeiDataType, + OutDataType, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> +{ + static constexpr ck::index_t NDimSpatial = 2; + + using DeviceOp = + DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static constexpr auto N1Number = K1Number; + + // Bytes per 32 lds bank: 32 * 4 bytes + static constexpr auto BankLength = 128; + static constexpr auto ElePerBank = BankLength / sizeof(ADataType); + + // M1 & M0 + static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1; + static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; + static constexpr auto ABlockLdsM1Padding = 4; + + // N1 & N0 + static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1; + static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; + static constexpr auto BBlockLdsN1Padding = 4; + + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + std::array input_left_pads, + std::array input_right_pads, + ck::index_t batch_k) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t 
InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmKTotal = N * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * X * Y; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + // A: output tensor + const index_t N0 = N / N1Number; + const index_t GemmK0Total = N0 * Ho * Wo; + + const index_t GemmK0S = + math::integer_divide_ceil(GemmK0Total, K0PerBlock * GemmKBatch) * K0PerBlock; + const index_t GemmK0Pad = GemmKBatch * GemmK0S; + const auto out_n_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Ho * Wo, K)); + + const auto out_n0_ho_wo_k_n1_grid_desc = + transform_tensor_descriptor(out_n_ho_wo_k_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)), + make_pass_through_transform(Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_gemmk0total_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(out_n0_ho_wo_k_n1_grid_desc, + make_tuple(make_merge_transform(make_tuple(N0, Ho * Wo)), + make_pass_through_transform(K), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_gemmk0pad_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmk0total_gemmm_gemmk1_grid_desc, + make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total), + make_pass_through_transform(GemmM), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmk0pad_gemmm_gemmk1_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)), + make_pass_through_transform(GemmM), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n0_y_ho_x_wo_c_n1_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(N0, N1Number)), + 
make_pass_through_transform(Y), + make_pass_through_transform(Ho), + make_pass_through_transform(X), + make_pass_through_transform(Wo), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 6>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto in_gemmk0total_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_n0_y_ho_x_wo_c_n1_grid_desc, + make_tuple(make_merge_transform(make_tuple(N0, Ho, Wo)), + make_merge_transform(make_tuple(Y, X, C)), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmk0pad_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0total_gemmn_gemmk1_grid_desc, + make_tuple(make_right_pad_transform(GemmK0Total, GemmK0Pad - GemmK0Total), + make_pass_through_transform(GemmN), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0pad_gemmn_gemmk1_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0)), + make_pass_through_transform(GemmN), + make_pass_through_transform(N1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1)); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + + using 
GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true>; + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + + using Block2CTileMap = + decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + std::array input_left_pads, + std::array input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + c_element_op_{wei_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + output_spatial_lengths_{output_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + 
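+                // Build the blocked C descriptor only after the GEMM problem has
+                // passed the validity check.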
c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation a_element_op_; + OutElementwiseOperation b_element_op_; + WeiElementwiseOperation c_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::array output_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array conv_filter_strides_; + std::array input_left_pads_; + std::array input_right_pads_; + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void Print(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + Print(arg); + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight has invalid setting"); + } + const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + hipGetErrorString(hipMemsetAsync( + arg.p_c_grid_, + 0, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + if(has_main_k0_block_loop) + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + true>; + + Run(kernel); + } + } + else + { + if(kbatch == 1) + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_bwd_weight< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + false>; + + Run(kernel); + } + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // unmerge N to N0 and N1, where N1 equals to K1 + if(!(arg.Conv_N_ % K1 == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + 
arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + std::array input_left_pads, + std::array input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::array input_spatial_lengths, + std::array filter_spatial_lengths, + std::array output_spatial_lengths, + std::array conv_filter_strides, + std::array conv_filter_dilations, + std::array input_left_pads, + std::array input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp new file mode 100755 index 000000000000..d4f89b3e09d7 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,767 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. 
All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvBwdData<2, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK, + InDataType, + WeiDataType, + OutDataType, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> +{ + using DeviceOp = DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = OutDataType; + using BDataType = WeiDataType; + using CDataType = InDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static_assert((K1 % ABlockTransferThreadClusterLengths_K0_M_K1{}[I2]) % + ABlockTransferSrcScalarPerVector == + 0); + static_assert((NPerBlock / BBlockTransferThreadClusterLengths_K0_N_K1{}[I1]) % + BBlockTransferSrcScalarPerVector == + 0); + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + index_t i_ytilde, + index_t i_xtilde) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const auto K0 = K / K1; + + const auto out_n_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + const auto wei_k_y_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C)); + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + if 
constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + 
make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto 
in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 0, 0)); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, + 7, // CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { 
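+        // The constructor below implements the backward-data decomposition used
+        // throughout this file: A is the output tensor, B the weight tensor and
+        // C the input tensor (see the p_a_grid_/p_b_grid_/p_c_grid_ initializers).
+        // It splits the convolution into YTilde * XTilde sub-GEMMs, one per
+        // (i_ytilde, i_xtilde) pair, skips empty slices (YDotSlice * XDotSlice <= 0),
+        // and stores one A/B/C grid-descriptor triple per remaining slice;
+        // Invoker::Run then launches one gridwise GEMM per stored triple and
+        // accumulates the measured time.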
+ Argument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_in_grid}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const index_t Y = filter_spatial_lengths_[0]; + const index_t X = filter_spatial_lengths_[1]; + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto descs = DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + i_ytilde, + i_xtilde); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + } + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + std::vector a_grid_desc_k0_m_k1_container_; + std::vector b_grid_desc_k0_n_k1_container_; + std::vector c_grid_desc_m_n_container_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + { + std::cout << "arg.a_grid_desc_k0_m_k1_container_{" + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_container_{" + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " + << 
arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m_n_container_{ " + << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " + << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" + << std::endl; + } + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i])) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + + const auto [gdx, gdy, gdz] = + GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]); + + const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * + arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = + kernel_gemm_xdlops_v2r3; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); + } + else + { + const auto kernel = + kernel_gemm_xdlops_v2r3; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); + } + } + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 1 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i])) + { + return false; + } + } + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, 
+ std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(void* p_in_grid, + const void* p_wei_grid, + const void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp new file mode 100755 index 000000000000..a8eb73d730df --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,983 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
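+//
+// Forward 2D convolution with fused bias, activation and residual add for
+// NHWC input / KYXC weight / NHWK output:
+//   out[N, Ho, Wo, K] = activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) + residual[N, Ho, Wo, K]
+// The convolution is lowered to an implicit GEMM (GemmM = N * Ho * Wo,
+// GemmN = K, GemmK = Y * X * C) and dispatched to the xdlops v3r3 gridwise
+// GEMM with a C-shuffle epilogue; the bias is broadcast along GemmM through a
+// (0, 1)-strided descriptor and the residual reuses the output layout.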
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation_add.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = +// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) + residual[N, Ho, Wo, K] +template < + typename InDataType, + typename WeiDataType, + typename OutDataType, + typename AccDataType, + typename InElementwiseOperation, + typename WeiElementwiseOperation, + typename OutElementwiseOperation, + ConvolutionForwardSpecialization ConvForwardSpecialization, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXDL, + ck::index_t NPerXDL, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwdBiasActivationAdd +{ + using DeviceOp = + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + // TODO make it support any # of spatial dimensions + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + 
const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + + const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock); + const auto GemmMPad = GemmM - GemmMRaw; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { // 1x1, stride=1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { // 1x1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, 
Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC) + { // C = odd value + const index_t GemmKRaw = Y * X * C; + const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); + const index_t GemmKPad = GemmK - GemmKRaw; + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto 
in_gemmkraw_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + in_gemmkraw_gemmmraw_grid_desc, + make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), + make_right_pad_transform(GemmKRaw, GemmKPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + else + { + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = 
transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + // C1: residual tensor: assume same layout as output tensor + const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc; + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn, + resi_grid_desc_gemmm_gemmn); + } + } + + using GridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = 
remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using C0GridDesc_M_N = remove_cvref_t; + using C1GridDesc_M_N = remove_cvref_t; + + using Block2CTileMap = BlockToCTileMap_M00_N0_M01; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + C0GridDesc_M_N, + C1GridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + const OutDataType* p_resi_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + p_c0_grid_{p_bias_grid}, + p_c1_grid_{p_resi_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c0_grid_desc_m_n_{}, + c1_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + 
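+            // descs holds the five grid descriptors returned above:
+            // [I0] = A (input, K0 x M x K1), [I1] = B (weight, K0 x N x K1),
+            // [I2] = C (output, M x N), [I3] = bias, [I4] = residual.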
a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + c0_grid_desc_m_n_ = descs[I3]; + c1_grid_desc_m_n_ = descs[I4]; + + block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_}; + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c0_grid_desc_m_n_); + + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c1_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const CDataType* p_c0_grid_; + const CDataType* p_c1_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_M_N c0_grid_desc_m_n_; + C1GridDesc_M_N c1_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + Block2CTileMap block_2_ctile_map_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << DeviceOp{}.GetTypeString() << std::endl; + std::cout << "N " << arg.Conv_N_ << ", " + << "K " << arg.Conv_K_ << ", " + << "C " << arg.Conv_C_ << ", " << std::endl; + std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", " + << arg.filter_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", " + << arg.input_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", " + << arg.output_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Strides " << arg.conv_filter_strides_[0] << ", " + << arg.conv_filter_strides_[1] << ", " << std::endl; + std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", " + << arg.conv_filter_dilations_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", " + << arg.input_left_pads_[1] << 
", " << std::endl; + std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " + << arg.input_right_pads_[1] << ", " << std::endl; + + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v3r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + Block2CTileMap, + true>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.p_c1_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v3r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + Block2CTileMap, + false>; + + ave_time = launch_and_time_kernel( + 
stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.p_c1_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + const OutDataType* p_resi_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + p_bias_grid, + p_resi_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + 
MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + const void* p_bias_grid, + const void* p_resi_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + static_cast(p_bias_grid), + static_cast(p_resi_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl + << ">"; + // clang-format on + + return str.str(); + } +}; +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp new file mode 100755 index 000000000000..6eb9281d30c9 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,940 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
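+//
+// Forward 2D convolution with fused bias and activation for NHWC input /
+// KYXC weight / NHWK output:
+//   out[N, Ho, Wo, K] = activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K])
+// Same implicit-GEMM lowering as the bias+activation+add variant above
+// (GemmM = N * Ho * Wo, GemmN = K, GemmK = Y * X * C), but dispatched to the
+// xdlops v3r2 gridwise GEMM and parameterized by OutGlobalMemoryDataOperation
+// for the global store of the output tile.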
+ +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd_bias_activation.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = +// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) +template < + typename InDataType, + typename WeiDataType, + typename OutDataType, + typename AccDataType, + typename InElementwiseOperation, + typename WeiElementwiseOperation, + typename OutElementwiseOperation, + InMemoryDataOperationEnum OutGlobalMemoryDataOperation, + ConvolutionForwardSpecialization ConvForwardSpecialization, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXDL, + ck::index_t NPerXDL, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwdBiasActivation +{ + using DeviceOp = + DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + // TODO make it support any # of spatial dimensions + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo 
= output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + + const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock); + const auto GemmMPad = GemmM - GemmMRaw; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { // 1x1, stride=1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { // 1x1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + 
in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC) + { // C = odd value + const index_t GemmKRaw = Y * X * C; + const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); + const index_t GemmKPad = GemmK - GemmKRaw; + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmkraw_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk_gemmm_grid_desc = 
transform_tensor_descriptor( + in_gemmkraw_gemmmraw_grid_desc, + make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), + make_right_pad_transform(GemmKRaw, GemmKPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + else + { + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + 
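+            // generic path: merge (Y, X, C) into GemmK and (N, Ho, Wo) into GemmM,
+            // i.e. the standard implicit-GEMM (im2col) view of NHWC forward convolution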
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // C0: bias tensor: assume a contiguous vector + const auto bias_grid_desc_gemmm_gemmn = + make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1)); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc, + bias_grid_desc_gemmm_gemmn); + } + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using C0GridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + OutGlobalMemoryDataOperation, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + C0GridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // 
ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + ck::index_t M01, + ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + p_c0_grid_{p_bias_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c0_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + c0_grid_desc_m_n_ = descs[I3]; + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c0_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const CDataType* p_c0_grid_; + 
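+        // p_c0_grid_ is the bias vector of length K (= GemmN); its grid descriptor is built
+        // with strides (I0, I1), i.e. stride 0 along GemmM, so the same K bias values are
+        // broadcast to every output row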
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_M_N c0_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << DeviceOp{}.GetTypeString() << std::endl; + std::cout << "N " << arg.Conv_N_ << ", " + << "K " << arg.Conv_K_ << ", " + << "C " << arg.Conv_C_ << ", " << std::endl; + std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", " + << arg.filter_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", " + << arg.input_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", " + << arg.output_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Strides " << arg.conv_filter_strides_[0] << ", " + << arg.conv_filter_strides_[1] << ", " << std::endl; + std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", " + << arg.conv_filter_dilations_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", " + << arg.input_left_pads_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " + << arg.input_right_pads_[1] << ", " << std::endl; + + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0) + << ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r2 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v3r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v3r2< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + remove_reference_t< + typename GridwiseGemm:: + C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 
&& + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + const OutDataType* p_bias_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + p_bias_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + const void* p_bias_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + static_cast(p_bias_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << 
CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl + << ">"; + // clang-format on + + return str.str(); + } +}; +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp new file mode 100755 index 000000000000..5fad21f521eb --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,906 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template < + typename InDataType, + typename WeiDataType, + typename OutDataType, + typename AccDataType, + typename InElementwiseOperation, + typename WeiElementwiseOperation, + typename OutElementwiseOperation, + ConvolutionForwardSpecialization ConvForwardSpecialization, + ck::index_t BlockSize, + ck::index_t MPerBlock, + ck::index_t NPerBlock, + ck::index_t K0PerBlock, + ck::index_t K1, + ck::index_t MPerXdl, + ck::index_t NPerXdl, + ck::index_t MXdlPerWave, + ck::index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + ck::index_t ABlockTransferSrcVectorDim, + ck::index_t ABlockTransferSrcScalarPerVector, + ck::index_t ABlockTransferDstScalarPerVector_K1, + bool ABlockLdsAddExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + ck::index_t BBlockTransferSrcVectorDim, + ck::index_t BBlockTransferSrcScalarPerVector, + ck::index_t BBlockTransferDstScalarPerVector_K1, + bool BBlockLdsAddExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl> +struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwd<2, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK, + InDataType, + WeiDataType, + OutDataType, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> +{ + using DeviceOp = DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t 
NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + + const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock); + const auto GemmMPad = GemmM - GemmMRaw; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { // 1x1, stride=1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { // 1x1, pad=0 + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + 
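+            // GemmK is split into (GemmK0, GemmK1) with GemmK1 fixed at compile time
+            // (K1Number); e.g. for a 1x1 filter with C = 256 and K1 = 8, GemmK0 = 32.
+            // The assert above requires C to be divisible by K1 in this specialization.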
const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::OddC) + { // C = odd value + const index_t GemmKRaw = Y * X * C; + const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number); + const index_t GemmKPad = GemmK - GemmKRaw; + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + 
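+                // lower dims {N, Hip, Wip, C} map to upper dims {N, (Y, Ho), (X, Wo), C};
+                // the embed transforms encode the implicit im2col indexing, stepping by
+                // dilation over the filter window and by stride over the output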
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmkraw_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + in_gemmkraw_gemmmraw_grid_desc, + make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), + make_right_pad_transform(GemmKRaw, GemmKPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else + { + const index_t GemmK = Y * X * C; + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + 
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using Block2CTileMap = BlockToCTileMap_M00_N0_M01; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, // TODO: Add ShuffleType for DeviceConv2d + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + 
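+        // tile-size parameters follow; unlike the v3r2 instantiation above, this v3r1
+        // pipeline appears to take the total K per block (K0PerBlock * K1) together with
+        // separate AK1/BK1 values instead of K0PerBlock directly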
MPerBlock, + NPerBlock, + K0PerBlock * K1, + K1, // AK1 + K1, // BK1 + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + CBlockTransferScalarPerVector_NWaveNPerXdl>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{}, + block_2_ctile_map_{}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = Block2CTileMap{c_grid_desc_m_n_}; + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ = + GridwiseGemm:: + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + c_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_; + Block2CTileMap block_2_ctile_map_; + 
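+        // all grid descriptors and the block-to-C-tile map are built once in the constructor
+        // above, so Invoker::Run only has to validate them and launch the kernel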
InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << DeviceOp{}.GetTypeString() << std::endl; + std::cout << "N " << arg.Conv_N_ << ", " + << "K " << arg.Conv_K_ << ", " + << "C " << arg.Conv_C_ << ", " << std::endl; + std::cout << "Y X " << arg.filter_spatial_lengths_[0] << ", " + << arg.filter_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Hi Wi " << arg.input_spatial_lengths_[0] << ", " + << arg.input_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Ho Wo " << arg.output_spatial_lengths_[0] << ", " + << arg.output_spatial_lengths_[1] << ", " << std::endl; + std::cout << "Strides " << arg.conv_filter_strides_[0] << ", " + << arg.conv_filter_strides_[1] << ", " << std::endl; + std::cout << "Dilations " << arg.conv_filter_dilations_[0] << ", " + << arg.conv_filter_dilations_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_left_pads_[0] << ", " + << arg.input_left_pads_[1] << ", " << std::endl; + std::cout << "InLeftPads " << arg.input_right_pads_[0] << ", " + << arg.input_right_pads_[1] << ", " << std::endl; + + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout + << "arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_" + "nwavenperxdl_{ " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I0) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I1) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I2) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I3) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I4) + << ", " + << arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ + .GetLength(I5) + << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdlops_v3r1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + Block2CTileMap, + true>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v3r1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm:: + CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + Block2CTileMap, + false>; + + ave_time = launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + 
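+        // (K is the contiguous dimension of the NHWK output, so it must be divisible by the
+        // per-thread store vector width for the vectorized global write to be legal)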
if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << K1 << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp new file mode 100755 index 000000000000..c7aa54f1d9c5 --- /dev/null +++ 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,680 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K + : public DeviceConvFwd<2, + ck::tensor_layout::convolution::NHWC, + ck::tensor_layout::convolution::KYXC, + ck::tensor_layout::convolution::NHWK, + InDataType, + WeiDataType, + OutDataType, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> +{ + using DeviceOp = DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr index_t NDimSpatial = 2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmMRaw = N * Ho * Wo; + const index_t GemmN = K; + const index_t GemmK = Y * X * C; + + const auto GemmMPad = math::integer_least_multiple(GemmMRaw, MPerBlock) - GemmMRaw; + + assert(GemmK % GemmK1Number == 0); + + const index_t GemmK0 = GemmK / GemmK1Number; + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // A: input tensor + const auto in_gemmmraw_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C)); + + const auto 
in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmmraw_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_right_pad_transform(GemmMRaw, GemmMPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)), + make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_n_ho_wo_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_gemmn_gemmk_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, C)); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmn_gemmk_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_gemmmraw_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return 
make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + else + { + // A: input tensor + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmmraw_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk_gemmmraw_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmMRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmk0_gemmmraw_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B: weight tensor + const auto wei_k_yxc_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + wei_k_yxc_grid_desc, + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: output tensor + const auto out_nhowo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmmraw_gemmn_grid_desc = + transform_tensor_descriptor(out_nhowo_k_grid_desc, + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmm_gemmn_grid_desc = + transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + 
make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + out_gemmm_gemmn_grid_desc); + } + } + + using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1})); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, // ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder, + 2, // BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, + 7, // CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + : p_a_grid_{p_in_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_out_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + filter_spatial_lengths_{filter_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + std::vector filter_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + 
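+            // Flow of this Run(): 1) optionally log the A/B/C grid-descriptor lengths when
+            // CK_LOGGING is enabled, 2) validate them with GridwiseGemm::CheckValidity,
+            // 3) derive the launch grid from the C descriptor, 4) select the kernel
+            // instantiation based on CalculateHasMainKBlockLoop(K), and 5) launch and time it
+            // with launch_and_time_kernel.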
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = + kernel_gemm_xdlops_v2r3; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_); + } + else + { + const auto kernel = + kernel_gemm_xdlops_v2r3; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 && + arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 && + arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_K_ % CThreadTransferDstScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); + } + + bool 
IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + const void* p_wei_grid, + void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp new file mode 100755 index 000000000000..cd89f3232cb9 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_naive_ndhwc_kzyxc_ndhwk.hpp @@ -0,0 +1,268 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
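+// This header implements a naive (reference) forward 3D convolution device op:
+//   out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
+// Invoker::Run launches the reference kernel ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk with a
+// fixed grid/block of 256, so this op is intended for correctness checking rather than speed.
+//
+// Typical driver-side usage (illustrative sketch only; `conv` stands for an instance of this
+// device op with concrete template arguments, and the pointer/operator variables are
+// placeholders):
+//
+//   auto arg     = conv.MakeArgument(p_in, p_wei, p_out, N, K, C,
+//                                    input_spatial_lengths, filter_spatial_lengths,
+//                                    output_spatial_lengths, conv_filter_strides,
+//                                    conv_filter_dilations, input_left_pads, input_right_pads,
+//                                    in_element_op, wei_element_op, out_element_op);
+//   auto invoker = conv.MakeInvoker();
+//   if(conv.IsSupportedArgument(arg))
+//   {
+//       float ave_ms = invoker.Run(arg, StreamConfig{});
+//   }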
+ +#ifndef DEVICE_CONV3D_FWD_NAIVE_HPP +#define DEVICE_CONV3D_FWD_NAIVE_HPP + +#include +#include +#include +#include "conv_util.hpp" +#include "device.hpp" +#include "device_conv_fwd.hpp" +#include "common_header.hpp" +#include "naive_conv_fwd.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// specialization for #D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] +template +struct DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K + : public DeviceConvFwd + +{ + using DeviceOp = DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + // TODO make A/B datatype different + using ABDataType = InDataType; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in, + const WeiDataType* p_wei, + OutDataType* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : params_{3, + N, + K, + C, + filter_spatial_lengths, + input_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, + out_spatial_lengths_{output_spatial_lengths}, + p_in_{p_in}, + p_wei_{p_wei}, + p_out_{p_out}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + + { + } + + // private: + utils::conv::ConvParams params_; + std::vector out_spatial_lengths_; + + const InDataType* p_in_; + const WeiDataType* p_wei_; + OutDataType* p_out_; + + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto naive_conv3d_fwd = + ref::naive_conv_fwd_ndhwc_kzyxc_ndhwk; + + float ave_time = launch_and_time_kernel(stream_config, + naive_conv3d_fwd, + dim3(256), + dim3(256), + 0, + arg.p_in_, + arg.p_wei_, + arg.p_out_, + arg.N_, + arg.K_, + arg.C_, + arg.in_spatial_lengths_[0], + arg.in_spatial_lengths_[1], + arg.in_spatial_lengths_[2], + arg.filter_spatial_lengths_[0], + arg.filter_spatial_lengths_[1], + arg.filter_spatial_lengths_[2], + arg.out_spatial_lengths_[0], + arg.out_spatial_lengths_[1], + arg.out_spatial_lengths_[2], + arg.conv_filter_strides_[0], + arg.conv_filter_strides_[1], + arg.conv_filter_strides_[2], + arg.conv_filter_dilations_[0], + arg.conv_filter_dilations_[1], + arg.conv_filter_dilations_[2], + arg.in_left_pads_[0], + arg.in_left_pads_[1], + arg.in_left_pads_[2]); + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + std::vector out_spatial_lengths = arg.params_.GetOutputSpatialLengths(); + + bool 
out_lengths_are_consistent = out_spatial_lengths[0] == arg.out_spatial_lengths_[0] && + out_spatial_lengths[1] == arg.out_spatial_lengths_[1] && + out_spatial_lengths[2] == arg.out_spatial_lengths_[2]; + return out_lengths_are_consistent; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in, + const WeiDataType* p_wei, + OutDataType* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in, + p_wei, + p_out, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_in, + const void* p_wei, + void* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + + { + return std::make_unique(static_cast(p_in), + static_cast(p_wei), + static_cast(p_out), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K<>"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp new file mode 100755 index 000000000000..68ec8187a4d7 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp @@ -0,0 +1,661 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
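+// This header implements the xdlops-based forward 3D convolution device op:
+//   out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
+// The convolution is rewritten as an implicit batched GEMM via
+// transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad and executed by
+// GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 through kernel_gemm_xdlops_v2r3_for_conv3d.
+// To keep tensor offsets within index_t, the batch N is split into sub-batches:
+// GetMaxAllowableSubBatchSize picks the largest N1 such that N % N1 == 0,
+// N1 * Do * Ho * Wo * K < 2^31 - 1 and N1 * Di * Hi * Wi * C < 2^31 - 1; the launch grid is
+// then multiplied by num_subbatches_ = N / N1 and each workgroup applies the corresponding
+// per-tensor batch offset.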
+ +#ifndef DEVICE_CONV3D_FWD_XDL_HPP +#define DEVICE_CONV3D_FWD_XDL_HPP + +#include +#include +#include +#include "device.hpp" +#include "device_conv_fwd.hpp" +#include "common_header.hpp" +#include "ck/utility/env.hpp" +#include "tensor_layout.hpp" +#include "convolution_forward_specialization.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk.hpp" +#include "gridwise_gemm_xdlops_v2r3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/* + * \see \link impl/device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3() \endlink. + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r3_for_conv3d( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const index_t num_batches, + const index_t a_batch_stride, + const index_t b_batch_stride, + const index_t c_batch_stride, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast(a_batch_stride) * g_idx); + const long_index_t b_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast(b_batch_stride) * g_idx); + const long_index_t c_batch_offset = + __builtin_amdgcn_readfirstlane(static_cast(c_batch_stride) * g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = num_batches; + ignore = a_batch_stride; + ignore = b_batch_stride; + ignore = c_batch_stride; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx9__)) +} + +// specialization for #D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k] +template +struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K + : public DeviceConvFwd + +{ + using DeviceOp = DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K; + + using ADataType = InDataType; + using BDataType = WeiDataType; + using CDataType = OutDataType; + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + /* + * \brief Split the number of batches, \p N, into N = B * N1, such that the 
memory + * space of input and output tensors stays with the value range of index_t, and each subbatch + * can be dealed with GridwiseGemm. + */ + static index_t GetMaxAllowableSubBatchSize(const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector output_spatial_lengths) + { + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + // N1 should satisfy that + // 1) N % N1 = 0; + // 2) N1 * (Do * Ho * Wo * K) < (2^31 - 1) + // 3) N1 * (Di * Hi * Wi * C) < (2^31 - 1) + // + // Do NOT confuse (B, N1) in this function with (B, N1) in gridewise GEMM. + auto N1 = N + 1; + + const auto stride = + math::max(long_index_t(Do) * Ho * Wo * K, long_index_t(Di) * Hi * Wi * C); + const index_t max_stride = NumericLimits::Max(); + + for(index_t n0 = 1; n0 <= N; ++n0) + { + index_t n1 = N / n0; + if(n0 * n1 == N && long_index_t(n1) * long_index_t(stride) < max_stride) + { + N1 = n1; + break; + } + } + + const auto B = N / N1; + if(B * N1 != N) + { + throw std::runtime_error(__func__ + + std::string(": failed to find num_subbatches for conv3d.\n")); + } + + return N1; + } + + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + assert(input_spatial_lengths.size() > 2); + assert(filter_spatial_lengths.size() > 2); + assert(conv_filter_strides.size() > 2); + assert(conv_filter_dilations.size() > 2); + assert(input_left_pads.size() > 2); + assert(input_right_pads.size() > 2); + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + static_assert(ConvForwardSpecialization == ConvolutionForwardSpecialization::Default, + "Wrong! 
This specialization not implemented!"); + + const auto in_desc_n_di_hi_wi_c = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + const auto wei_desc_k_z_y_x_c = + make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C)); + const auto out_desc_n_do_ho_wo_k = + make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K)); + + const auto descs = transform_forward_convolution3d_into_gemm_v4r4r4_ndhwc_kzyxc_ndhwk_pad( + in_desc_n_di_hi_wi_c, + wei_desc_k_z_y_x_c, + out_desc_n_do_ho_wo_k, + make_tuple(conv_filter_strides[0], conv_filter_strides[1], conv_filter_strides[2]), + make_tuple( + conv_filter_dilations[0], conv_filter_dilations[1], conv_filter_dilations[2]), + make_tuple(input_left_pads[0], input_left_pads[1], input_left_pads[2]), + make_tuple(input_right_pads[0], input_right_pads[1], input_right_pads[2]), + Number{}); + + return descs; + } + + using ABCGridDescs = remove_cvref_t; + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + InDataType, + AccDataType, + OutDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder, + Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 7, + CThreadTransferDstScalarPerVector>; + + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); + using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in, + const WeiDataType* p_wei, + OutDataType* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + index_t M01, + index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_in}, + p_b_grid_{p_wei}, + p_c_grid_{p_out}, + M01_{M01}, + N01_{N01}, + in_element_op_{in_element_op}, + wei_element_op_{wei_element_op}, + out_element_op_{out_element_op} + { + const index_t subbatch_size = + GetMaxAllowableSubBatchSize(N, K, C, input_spatial_lengths, output_spatial_lengths); + num_subbatches_ = N / subbatch_size; + + const auto descs = + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(subbatch_size, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + 
conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + a_batch_stride_ = a_grid_desc_k0_m_k1_.GetElementSpaceSize(); + b_batch_stride_ = 0; + c_batch_stride_ = c_grid_desc_m_n_.GetElementSpaceSize(); + + if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_, + b_grid_desc_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + } + } + + // private: + const InDataType* p_a_grid_; + const WeiDataType* p_b_grid_; + OutDataType* p_c_grid_; + index_t num_subbatches_; + index_t a_batch_stride_; + index_t b_batch_stride_; + index_t c_batch_stride_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + Block2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + InElementwiseOperation in_element_op_; + WeiElementwiseOperation wei_element_op_; + OutElementwiseOperation out_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "num_batches_of_GEMM = " << arg.num_subbatches_ << std::endl; + std::cout << "a_grid_desc_k0_m_k1{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "b_grid_desc_k0_n_k1{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "c_grid_desc_m_n{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * + arg.num_subbatches_; + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_v2r3_for_conv3d< + GridwiseGemm, + InDataType, + OutDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + true>; + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.num_subbatches_, + arg.a_batch_stride_, + arg.b_batch_stride_, + arg.c_batch_stride_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3_for_conv3d< + GridwiseGemm, + InDataType, + OutDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.num_subbatches_, + arg.a_batch_stride_, + arg.b_batch_stride_, + arg.c_batch_stride_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.in_element_op_, + arg.wei_element_op_, + arg.out_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const InDataType* p_in, + const WeiDataType* p_wei, + OutDataType* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in, + p_wei, + p_out, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_in, + const 
void* p_wei, + void* p_out, + const index_t N, + const index_t K, + const index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + + { + return std::make_unique(static_cast(p_in), + static_cast(p_wei), + static_cast(p_out), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp new file mode 100755 index 000000000000..ddabd61c3d8f --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp @@ -0,0 +1,1587 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
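+// This header implements the DL (non-xdlops) backward-data convolution device op
+// DeviceConvNdBwdDataNwcKxcNwk_Dl, which computes the input gradient from the output
+// gradient and the weights (A = out, B = wei, C = in). The 1D, 2D and 3D cases are handled
+// by separate overloads of MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N selected by
+// NDimSpatial. For general strides/dilations the problem is decomposed into several GEMMs
+// indexed by the "tilde" offsets (i_xtilde, i_ytilde, i_ztilde); e.g. for the W dimension
+//   XTilde = ConvStrideW / gcd(ConvStrideW, ConvDilationW)
+// and, roughly, the GEMM for offset i_xtilde covers the filter taps at positions
+// i_xtilde, i_xtilde + XTilde, ... (XDotSlice of them). The Filter1x1Stride1Pad0
+// specialization degenerates into a single plain GEMM.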
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConvNdBwdDataNwcKxcNwk_Dl + : public DeviceConvBwdData< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, + InDataType, + WeiDataType, + OutDataType, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> +{ + using DeviceOp = DeviceConvNdBwdDataNwcKxcNwk_Dl; + + using ADataType = OutDataType; + using BDataType = WeiDataType; + using CDataType = InDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector tildes) + { + using namespace ck; + + index_t i_xtilde = tildes[0]; + + const index_t Wi = input_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + const index_t ConvStrideW = conv_filter_strides[0]; + const index_t ConvDilationW = conv_filter_dilations[0]; + + const auto K0 = K / K1; + + const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)), + make_tuple(make_pass_through_transform(N * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Wo), 
make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto out_n_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wo, K)); + const auto wei_k_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, X, C)); + + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_wop_k_grid_desc = transform_tensor_descriptor( + out_n_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor( + out_n_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)), + make_merge_transform(make_tuple(N, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_k_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / 
GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor( + wei_k_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<>{}, Sequence<3>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector tildes) + { + using namespace ck; + + index_t i_ytilde = tildes[0]; + index_t i_xtilde = tildes[1]; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = 
input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const auto K0 = K / K1; + + const auto out_n_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + const auto wei_k_y_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C)); + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + 
I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + 
make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector tildes) + { + using namespace ck; + + const index_t i_ztilde = tildes[0]; + const index_t i_ytilde = tildes[1]; + const index_t i_xtilde = tildes[2]; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = 
output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const auto K0 = K / K1; + + const auto out_n_do_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K)); + const auto wei_k_z_y_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C)); + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Do * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<0, 2, 4, 6>{}, + Sequence<7>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = ConvStrideD / GcdStrideDilationD; + const auto YTilde = 
ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto ZDot = math::integer_divide_ceil(Z, ZTilde); + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto DTilde = + Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD); + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IDTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IDTildeSliceEnd = math::min( + DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_do_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Do, I0, I0), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc = + transform_tensor_descriptor( + out_n_dop_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(ZDot, DTilde), + make_tuple(-ConvDilationD / GcdStrideDilationD, I1)), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), 
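+ // Slice each Dot dimension to the filter taps handled by this GEMM (offset 0, length *DotSlice) and each Tilde dimension to the window starting at I*TildeSliceBegin that overlaps the non-padded input region.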
+ make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7, 8>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple( + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc = + transform_tensor_descriptor( + wei_k_z_y_x_c_grid_desc, + make_tuple( + make_pass_through_transform(K), + make_embed_transform(make_tuple(ZDot, ZTilde), + make_tuple(ConvStrideD / GcdStrideDilationD, I1)), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ztilde), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<5>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + 
in_n_dip_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(ZTilde, DTilde), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ztilde), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<3>{}, + Sequence<4>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0}); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, + 1, + 1, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {0, 0, 0}); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDl_km_kn_mn_v1r3; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using DefaultBlock2CTileMap = + decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); + // Argument + struct Argument : public BaseArgument + { + Argument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector 
filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_in_grid}, + a_element_op_{out_element_op}, + b_element_op_{wei_element_op}, + c_element_op_{in_element_op}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + CreateABCDesc(); + } + + template ::type = false> + void CreateABCDesc() + { + const index_t ConvStrideW = conv_filter_strides_[0]; + const index_t ConvDilationW = conv_filter_dilations_[0]; + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const index_t X = filter_spatial_lengths_[0]; + + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(XDotSlice <= 0) + { + continue; + } + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + {i_xtilde}); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + + if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2])) + { + a_grid_desc_k0_m0_m1_k1_container_.push_back( + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(descs[I0])); + b_grid_desc_k0_n0_n1_k1_container_.push_back( + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(descs[I1])); + c_grid_desc_m0_m10_m11_n0_n10_n11_container_.push_back( + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(descs[I2])); + + block_2_ctile_map_container_.push_back( + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2])); + } + } + } + template ::type = false> + void CreateABCDesc() + { + const index_t ConvStrideH = conv_filter_strides_[0]; + const index_t ConvStrideW = conv_filter_strides_[1]; + + const index_t ConvDilationH = conv_filter_dilations_[0]; + const index_t ConvDilationW = conv_filter_dilations_[1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const index_t Y = filter_spatial_lengths_[0]; + const index_t X = filter_spatial_lengths_[1]; + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + 
input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + {i_ytilde, i_xtilde}); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + + if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2])) + { + a_grid_desc_k0_m0_m1_k1_container_.push_back( + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(descs[I0])); + b_grid_desc_k0_n0_n1_k1_container_.push_back( + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(descs[I1])); + c_grid_desc_m0_m10_m11_n0_n10_n11_container_.push_back( + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(descs[I2])); + + block_2_ctile_map_container_.push_back( + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2])); + } + } + } + } + template ::type = false> + void CreateABCDesc() + { + const index_t ConvStrideD = conv_filter_strides_[0]; + const index_t ConvStrideH = conv_filter_strides_[1]; + const index_t ConvStrideW = conv_filter_strides_[2]; + + const index_t ConvDilationD = conv_filter_dilations_[0]; + const index_t ConvDilationH = conv_filter_dilations_[1]; + const index_t ConvDilationW = conv_filter_dilations_[2]; + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = ConvStrideD / GcdStrideDilationD; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const index_t Z = filter_spatial_lengths_[0]; + const index_t Y = filter_spatial_lengths_[1]; + const index_t X = filter_spatial_lengths_[2]; + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(ZDotSlice * YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + {i_ztilde, i_ytilde, i_xtilde}); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + + if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2])) + { + a_grid_desc_k0_m0_m1_k1_container_.push_back( + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(descs[I0])); + b_grid_desc_k0_n0_n1_k1_container_.push_back( + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(descs[I1])); + c_grid_desc_m0_m10_m11_n0_n10_n11_container_.push_back( + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(descs[I2])); + + block_2_ctile_map_container_.push_back( + GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2])); + } + } + } + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + std::vector a_grid_desc_k0_m_k1_container_; + std::vector b_grid_desc_k0_n_k1_container_; + std::vector 
c_grid_desc_m_n_container_; + + std::vector a_grid_desc_k0_m0_m1_k1_container_; + std::vector b_grid_desc_k0_n0_n1_k1_container_; + std::vector c_grid_desc_m0_m10_m11_n0_n10_n11_container_; + + std::vector block_2_ctile_map_container_; + + // element-wise op + OutElementwiseOperation a_element_op_; + WeiElementwiseOperation b_element_op_; + InElementwiseOperation c_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "arg.a_grid_desc_k0_m_k1_container_{" + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_container_{" + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m_n_container_{ " + << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " + << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_( " + << arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I0) + << ", " + << arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I1) + << ", " + << arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I2) + << ", " + << arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I3) + << ", " + << arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I4) + << ", " + << arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i].GetLength(I5) + << " ) " << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i])) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + + const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( + arg.c_grid_desc_m_n_container_[i]); + + auto launch_kernel = [&](auto has_main_k_block_loop, + auto has_double_tail_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr bool has_double_loop = has_double_tail_k_block_loop; + + const auto kernel = kernel_gemm_dl_v1r3< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + has_main_loop, + has_double_loop>; + + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_container_[i], + arg.b_grid_desc_k0_n0_n1_k1_container_[i], + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_[i], + arg.block_2_ctile_map_container_[i]); + }; + + const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_container_[i].GetLength(I0); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + launch_kernel(integral_constant{}, integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + launch_kernel(integral_constant{}, + integral_constant{}); + } + } + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // check device + if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() || + ck::is_gfx11_supported() || ck::is_gfx12_supported())) + { + return false; + } + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // matrix A + { + auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{}; + if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1) + { + return false; + } + if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0) + { + return false; + } + + const index_t K = arg.Conv_K_; + + if(K % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0) + { + return false; + } + } + + // matrix B + { + auto srcLoadLenghts = BBlockTransferThreadSliceLengths_K0_N0_N1_K1{}; + auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{}; + if(srcVectorLengths[I0] != 1 || srcVectorLengths[I3] != 1) + { + return false; + } + if(srcLoadLenghts[I1] % srcVectorLengths[I1] != 0 || + srcLoadLenghts[I2] % srcVectorLengths[I2] != 0) + { + return false; + } + + const index_t C = arg.Conv_K_; + + if(C % (srcVectorLengths[I1] * 
srcVectorLengths[I2]) != 0) + { + return false; + } + } + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0)) + { + std::cout << "Not surpport,because: arg.Conv_C_ % CThreadTransferDstScalarPerVector = " + << arg.Conv_C_ % CThreadTransferDstScalarPerVector << std::endl; + return false; + } + + // Gridwise GEMM size + for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i])) + { + return false; + } + } + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(void* p_in_grid, + const void* p_wei_grid, + const void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConvNdBwdDataNwcKxcNwk_Dl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 + << ">"; + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0){ + + str<< " Filter1x1Stride1Pad0"; + } + + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp new file mode 100755 index 000000000000..2881036beede --- /dev/null +++ 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp @@ -0,0 +1,1475 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceConvNdBwdDataNwcKxcNwk_Xdl + : public DeviceConvBwdData< + NDimSpatial, + ck::tuple_element_t>, + ck::tuple_element_t>, + ck::tuple_element_t>, + InDataType, + WeiDataType, + OutDataType, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation> +{ + using DeviceOp = DeviceConvNdBwdDataNwcKxcNwk_Xdl; + + using ADataType = OutDataType; + using BDataType = WeiDataType; + using CDataType = InDataType; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static_assert((K1 % ABlockTransferThreadClusterLengths_K0_M_K1{}[I2]) % + ABlockTransferSrcScalarPerVector == + 0); + static_assert((NPerBlock / BBlockTransferThreadClusterLengths_K0_N_K1{}[I1]) % + BBlockTransferSrcScalarPerVector == + 0); + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector tildes) + { + using namespace ck; + + index_t i_xtilde = tildes[0]; + + const index_t Wi = input_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + const index_t ConvStrideW = conv_filter_strides[0]; + const index_t ConvDilationW = conv_filter_dilations[0]; + + const auto K0 = K / K1; + + const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)), + make_tuple(make_pass_through_transform(N * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), 
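+ // Upper-dimension mapping: the flattened (N * Wo) index becomes GemmM (upper dim 1) and K is unmerged into (GemmK0, GemmK1) at upper dims (0, 2).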
+ make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto out_n_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wo, K)); + const auto wei_k_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, X, C)); + + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_wop_k_grid_desc = transform_tensor_descriptor( + out_n_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor( + out_n_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = 
transform_tensor_descriptor( + out_n_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)), + make_merge_transform(make_tuple(N, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_k_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor( + wei_k_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<>{}, Sequence<3>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 0>{}, Sequence<3>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector 
input_left_pads, + std::vector input_right_pads, + std::vector tildes) + { + using namespace ck; + + index_t i_ytilde = tildes[0]; + index_t i_xtilde = tildes[1]; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const auto K0 = K / K1; + + const auto out_n_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K)); + const auto wei_k_y_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C)); + const auto in_n_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), 
ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_k_y_x_c_grid_desc, + make_tuple(make_pass_through_transform(K), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + 
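// The embed transforms above decompose the filter indices as y = ydot * (ConvStrideH / GcdStrideDilationH) + ytilde and x = xdot * (ConvStrideW / GcdStrideDilationW) + xtilde.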
make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + + template ::type = false> + static auto + MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + 
std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + std::vector tildes) + { + using namespace ck; + + const index_t i_ztilde = tildes[0]; + const index_t i_ytilde = tildes[1]; + const index_t i_xtilde = tildes[2]; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const auto K0 = K / K1; + + const auto out_n_do_ho_wo_k_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K)); + const auto wei_k_z_y_x_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C)); + const auto in_n_di_hi_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Di, Hi, Wi, C)); + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Do * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Do, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<0, 2, 
4, 6>{}, + Sequence<7>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + else + { + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = ConvStrideD / GcdStrideDilationD; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const auto ZDot = math::integer_divide_ceil(Z, ZTilde); + const auto YDot = math::integer_divide_ceil(Y, YTilde); + const auto XDot = math::integer_divide_ceil(X, XTilde); + + const auto DTilde = + Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD); + const auto HTilde = + Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH); + const auto WTilde = + Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW); + + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IDTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD); + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW); + + const auto IDTildeSliceEnd = math::min( + DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1); + const auto IHTildeSliceEnd = math::min( + HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1); + + const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + // A: output tensor + const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_n_do_ho_wo_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Do, I0, I0), + make_pad_transform(Ho, I0, I0), + make_pad_transform(Wo, I0, I0), + make_pass_through_transform(K)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc = + transform_tensor_descriptor( + out_n_dop_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(ZDot, DTilde), + make_tuple(-ConvDilationD / GcdStrideDilationD, I1)), + make_embed_transform(make_tuple(YDot, HTilde), + make_tuple(-ConvDilationH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, WTilde), + make_tuple(-ConvDilationW / GcdStrideDilationW, I1)), + make_pass_through_transform(K)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + 
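// Upper dims: N stays at 0, each (Dot, Tilde) pair from the Do/Ho/Wo split occupies dims 1-6, and K stays at 7; the negative embed coefficients step backwards through the output as the filter-tap (dot) index increases.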
make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc = + transform_tensor_descriptor( + out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7, 8>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc, + make_tuple( + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)), + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(K1)), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}, Sequence<8>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // B weight tensor + const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc = + transform_tensor_descriptor( + wei_k_z_y_x_c_grid_desc, + make_tuple( + make_pass_through_transform(K), + make_embed_transform(make_tuple(ZDot, ZTilde), + make_tuple(ConvStrideD / GcdStrideDilationD, I1)), + make_embed_transform(make_tuple(YDot, YTilde), + make_tuple(ConvStrideH / GcdStrideDilationH, I1)), + make_embed_transform(make_tuple(XDot, XTilde), + make_tuple(ConvStrideW / GcdStrideDilationW, I1)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_slice_transform(ZDot, I0, ZDotSlice), + make_slice_transform(YDot, I0, YDotSlice), + make_slice_transform(XDot, I0, XDotSlice), + make_freeze_transform(i_ztilde), + make_freeze_transform(i_ytilde), + make_freeze_transform(i_xtilde), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0, 1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<5>{})); + + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)), + make_pass_through_transform(C), + make_pass_through_transform(K1)), + make_tuple(Sequence<2, 3, 4, 0>{}, Sequence<5>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // C: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = 
transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(ZTilde, DTilde), + make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(YTilde, HTilde), + make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(XTilde, WTilde), + make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_freeze_transform(i_ztilde), + make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(i_ytilde), + make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(i_xtilde), + make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<3>{}, + Sequence<4>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); + } + + } // function end + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0}); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0}); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1, + 1, + 1, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {0, 0, 0}); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< + BlockSize, + ABDataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + 
InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder, + 7, // CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { + Argument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_wei_grid}, + p_c_grid_{p_in_grid}, + Conv_N_{N}, + Conv_K_{K}, + Conv_C_{C}, + input_spatial_lengths_{input_spatial_lengths}, + filter_spatial_lengths_{filter_spatial_lengths}, + output_spatial_lengths_{output_spatial_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + CreateABCDesc(); + } + + template ::type = false> + void CreateABCDesc() + { + const index_t ConvStrideW = conv_filter_strides_[0]; + const index_t ConvDilationW = conv_filter_dilations_[0]; + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const index_t X = filter_spatial_lengths_[0]; + + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(XDotSlice <= 0) + { + continue; + } + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + {i_xtilde}); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + } + } + template ::type = false> + void CreateABCDesc() + { + const index_t ConvStrideH = conv_filter_strides_[0]; + const index_t ConvStrideW = conv_filter_strides_[1]; + + const index_t ConvDilationH = conv_filter_dilations_[0]; + const index_t ConvDilationW = conv_filter_dilations_[1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const index_t Y = filter_spatial_lengths_[0]; + const index_t X = filter_spatial_lengths_[1]; + for(index_t i_ytilde = 
0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + {i_ytilde, i_xtilde}); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + } + } + } + template ::type = false> + void CreateABCDesc() + { + const index_t ConvStrideD = conv_filter_strides_[0]; + const index_t ConvStrideH = conv_filter_strides_[1]; + const index_t ConvStrideW = conv_filter_strides_[2]; + + const index_t ConvDilationD = conv_filter_dilations_[0]; + const index_t ConvDilationH = conv_filter_dilations_[1]; + const index_t ConvDilationW = conv_filter_dilations_[2]; + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = ConvStrideD / GcdStrideDilationD; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + const index_t Z = filter_spatial_lengths_[0]; + const index_t Y = filter_spatial_lengths_[1]; + const index_t X = filter_spatial_lengths_[2]; + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + if(ZDotSlice * YDotSlice * XDotSlice <= 0) + { + continue; + } + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_, + {i_ztilde, i_ytilde, i_xtilde}); + a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); + b_grid_desc_k0_n_k1_container_.push_back(descs[I1]); + c_grid_desc_m_n_container_.push_back(descs[I2]); + } + } + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + std::vector a_grid_desc_k0_m_k1_container_; + std::vector b_grid_desc_k0_n_k1_container_; + std::vector c_grid_desc_m_n_container_; + OutElementwiseOperation a_element_op_; + WeiElementwiseOperation b_element_op_; + InElementwiseOperation c_element_op_; + // for checking IsSupportedArgument() + index_t Conv_N_; + index_t Conv_K_; + index_t Conv_C_; + + std::vector input_spatial_lengths_; + std::vector filter_spatial_lengths_; + std::vector output_spatial_lengths_; + std::vector conv_filter_strides_; + std::vector conv_filter_dilations_; + std::vector input_left_pads_; + std::vector input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + 
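+        // Clarifying note (summary of the loop below): Run() walks the descriptor
+        // containers filled by CreateABCDesc(), i.e. one GEMM per valid
+        // (i_ztilde, i_ytilde, i_xtilde) slice of the backward-data decomposition,
+        // launches kernel_gemm_xdlops_v2r3 for each slice, and accumulates the
+        // per-launch times reported by launch_and_time_kernel into the returned ave_time.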
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "arg.a_grid_desc_k0_m_k1{" + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1{" + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}" + << std::endl; + + std::cout << "arg.c_grid_desc_m_n{" + << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", " + << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}" + << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i])) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const auto [gdx, gdy, gdz] = + GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_container_[i]); + + const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * + arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = + kernel_gemm_xdlops_v2r3; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); + } + else + { + const auto kernel = + kernel_gemm_xdlops_v2r3; + + ave_time += launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_container_[i], + arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i]); + } + } + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 1 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], + 
arg.b_grid_desc_k0_n_k1_container_[i], + arg.c_grid_desc_m_n_container_[i])) + { + return false; + } + } + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(InDataType* p_in_grid, + const WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(void* p_in_grid, + const void* p_wei_grid, + const void* p_out_grid, + ck::index_t N, + ck::index_t K, + ck::index_t C, + std::vector input_spatial_lengths, + std::vector filter_spatial_lengths, + std::vector output_spatial_lengths, + std::vector conv_filter_strides, + std::vector conv_filter_dilations, + std::vector input_left_pads, + std::vector input_right_pads, + InElementwiseOperation, + WeiElementwiseOperation, + OutElementwiseOperation) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + N, + K, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceConvNdBwdDataNwcKxcNwk_Xdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 + << ">"; + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0){ + + str<< " Filter1x1Stride1Pad0"; + } + + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp new file mode 100755 index 000000000000..bdc6dc998180 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_elementwise_dynamic_vector_dims_impl.hpp @@ -0,0 +1,424 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
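+// Overview (descriptive, based on the implementation below): DeviceElementwiseImpl with
+// dynamically chosen vector dimensions. The lowest-stride dimension of the first output
+// tensor becomes M0 and the lowest-stride dimension of the first input tensor becomes M1;
+// the remaining dimensions are merged into M0 and both axes are right-padded up to
+// multiples of M0PerThread / M1PerThread (see MakeDescriptor). At launch time a dedicated
+// GridwiseElementwise instantiation is selected when input and output share the same
+// lowest-stride dimension.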
+ +#pragma once + +#include +#include + +#include "ck/utility/math.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceElementwiseImpl + : public DeviceElementwise +{ + static constexpr int NumInput = InDataTypeTuple::Size(); + static constexpr int NumOutput = OutDataTypeTuple::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static auto GenerateInDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + static auto GenerateOutDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple()); + using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple()); + + static index_t GetLowestStrideDim(const std::array& strides) + { + index_t most_continous_dim = NumDim - 1; + index_t most_continous_dim_stride = strides[most_continous_dim]; + for(index_t dim = 0; dim < NumDim; dim++) + { + if(strides[dim] < most_continous_dim_stride) + { + most_continous_dim_stride = strides[dim]; + most_continous_dim = dim; + } + } + return most_continous_dim; + } + + template + static auto PadInputOutputDescriptor(const InOutDescriptor& desc) + { + const auto M0 = desc.GetLength(I0); + const auto M1 = desc.GetLength(I1); + const auto pad_M0 = math::integer_divide_ceil(M0, M0PerThread) * M0PerThread - M0; + const auto pad_M1 = math::integer_divide_ceil(M1, M1PerThread) * M1PerThread - M1; + + const auto padded_desc = transform_tensor_descriptor( + desc, + make_tuple(make_right_pad_transform(M0, pad_M0), make_right_pad_transform(M1, pad_M1)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return padded_desc; + } + + static auto GenerateBatchDimsLenghtsTuple(const std::array& lengths, + const index_t M0_dim, + const index_t M1_dim) + { + // Generate batch dims, they will be merged to M0 + // Add one more dim than needed in case that M0 is equal to M1 + // If M0 is equal to M1, then will be one more batch dim + std::array batch_dims; + index_t batch_dim = 0; + for(index_t i = 0; i < NumDim; i++) + { + if(i != M0_dim && i != M1_dim) + { + batch_dims[batch_dim] = lengths[i]; + batch_dim++; + } + } + // Add dummy dim if M0_dim is not equal to M1_dim + if(M0_dim != M1_dim && NumDim >= 2) + batch_dims[NumDim - 2] = 1; + return generate_tuple([&](auto I) { return batch_dims[I]; }, Number{}); + } + + static auto MakeDescriptor(const std::array& lengths, + const std::array& in_strides, + const std::array& out_strides, + const std::array& desc_strides) + { + const auto M0_dim = GetLowestStrideDim(out_strides); + const auto M1_dim = GetLowestStrideDim(in_strides); + + // If M0_dim is equal to M1_dim, then make M0_dim dummy + const auto M0 = M0_dim == 
M1_dim ? I1 : lengths[M0_dim]; + const auto M1 = lengths[M1_dim]; + const auto M0_stride = M0_dim == M1_dim ? I1 : desc_strides[M0_dim]; + const auto M1_stride = desc_strides[M1_dim]; + + const auto batch_dims_lenghts = GenerateBatchDimsLenghtsTuple(lengths, M0_dim, M1_dim); + const auto batch_dims_strides = GenerateBatchDimsLenghtsTuple(desc_strides, M0_dim, M1_dim); + + const auto desc = make_naive_tensor_descriptor( + concat_tuple(batch_dims_lenghts, make_tuple(M0), make_tuple(M1)), + concat_tuple(batch_dims_strides, make_tuple(M0_stride), make_tuple(M1_stride))); + // Merged batch dims with M0 + const auto transforms = + make_tuple(make_merge_transform(concat_tuple(batch_dims_lenghts, make_tuple(M0))), + make_pass_through_transform(M1)); + using BatchElemsSequence = + typename arithmetic_sequence_gen<0, decltype(batch_dims_lenghts)::Size() + 1, 1>::type; + const auto lower_dims = make_tuple(BatchElemsSequence{}, Sequence{}); + const auto upper_dims = make_tuple(Sequence<0>{}, Sequence<1>{}); + // desc: (merged_dims + M0, M1) + auto merged_desc = transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims); + return PadInputOutputDescriptor(merged_desc); + } + + template + static auto GenerateInOutGridDescTuple() + { + std::array ones; + for(index_t d = 0; d < NumDim; d++) + { + ones[d] = 1; + } + + return generate_tuple([&](auto) { return MakeDescriptor(ones, ones, ones, ones); }, + Number{}); + }; + + using InGridDescTuple = decltype(GenerateInOutGridDescTuple()); + using OutGridDescTuple = decltype(GenerateInOutGridDescTuple()); + + using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt; + + using GridwiseElementwiseOp = GridwiseElementwise; + + using GridwiseElementwiseOpSameInOutVectorDim = GridwiseElementwise; + + struct Argument : public BaseArgument + { + Argument(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) + + : lengths_(lengths), + inStridesArray_(inStridesArray), + outStridesArray_(outStridesArray), + elementwise_op_(elementwise_op) + { + in_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(in_dev_buffers[I.value]); + }, + Number{}); + + out_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(out_dev_buffers[I.value]); + }, + Number{}); + } + + InDataTypePointerTuple in_dev_buffers_; + OutDataTypePointerTuple out_dev_buffers_; + + std::array lengths_; + std::array, NumInput> inStridesArray_; + std::array, NumOutput> outStridesArray_; + + ElementwiseOperation elementwise_op_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + auto in_grid_desc_tuple = generate_tuple( + [&](auto src_i) { + // Use Strides from first tensor to assert that M0 dim and + // M1 dim are the same for each tensor. 
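+                    // Note: the first two stride arrays only select which dimensions become
+                    // M0/M1; the last argument supplies the strides actually encoded in the
+                    // descriptor for tensor src_i.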
+ return MakeDescriptor(arg.lengths_, + arg.inStridesArray_[I0], + arg.outStridesArray_[I0], + arg.inStridesArray_[src_i]); + }, + Number{}); + + auto out_grid_desc_tuple = generate_tuple( + [&](auto dst_i) { + return MakeDescriptor(arg.lengths_, + arg.inStridesArray_[I0], + arg.outStridesArray_[I0], + arg.outStridesArray_[dst_i]); + }, + Number{}); + + const index_t M0 = in_grid_desc_tuple.At(I0).GetLength(Number{}); + const index_t M1 = in_grid_desc_tuple.At(I0).GetLength(Number{}); + + const auto block_2_tile_map = Block2TileMap(M0, M1); + const index_t grid_size = block_2_tile_map.CalculateGridSize(M0, M1); + + const bool in_out_same_vector_dim = GetLowestStrideDim(arg.inStridesArray_[I0]) == + GetLowestStrideDim(arg.outStridesArray_[I0]); + + const auto kernel = in_out_same_vector_dim + ? kernel_elementwise + : kernel_elementwise; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + in_grid_desc_tuple, + out_grid_desc_tuple, + arg.in_dev_buffers_, + arg.out_dev_buffers_, + block_2_tile_map, + arg.elementwise_op_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + const index_t M0_dim = GetLowestStrideDim(arg.inStridesArray_[I0]); + const index_t M1_dim = GetLowestStrideDim(arg.outStridesArray_[I0]); + + auto IsScalarPerVectorValid = [&](const std::array& lengths, + const std::array& strides, + index_t scalarPerVector, + index_t M_dim) { + if(scalarPerVector == 1) + { + return true; + } + if(strides[M_dim] == 1 && lengths[M_dim] % scalarPerVector == 0) + { + return true; + } + return false; + }; + + bool is_valid = true; + static_for<0, NumInput, 1>{}([&](auto I) { + static_assert(M0PerThread % InScalarPerVectorSeq::At(I) == 0 && + M1PerThread % InScalarPerVectorSeq::At(I) == 0); + is_valid &= IsScalarPerVectorValid( + arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I), M0_dim); + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + static_assert(M0PerThread % OutScalarPerVectorSeq::At(I) == 0 && + M1PerThread % OutScalarPerVectorSeq::At(I) == 0); + is_valid &= IsScalarPerVectorValid( + arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I), M1_dim); + }); + + return is_valid; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) + { + return Argument{lengths, + inStridesArray, + outStridesArray, + in_dev_buffers, + out_dev_buffers, + elementwise_op}; + } + + std::unique_ptr + MakeArgumentPointer(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op) override + { + return std::make_unique(lengths, + inStridesArray, + outStridesArray, + in_dev_buffers, + out_dev_buffers, + elementwise_op); + } + + static auto MakeInvoker() { return Invoker{}; } + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() 
const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceElementwiseImpl<"; + str << NumDim << ", "; + str << BlockSize << ", "; + str << M0PerBlock << ", "; + str << M1PerBlock << ", "; + str << M0PerThread << ", "; + str << M1PerThread << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp new file mode 100755 index 000000000000..c3416758d127 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp @@ -0,0 +1,597 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/math.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/reduction_operator.hpp" + +#include "ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +// X = Elementwise(input1, input2, input3, ...) +// Y = Normalization(X, beta, gamma) +namespace ck { +template // Descriptor of inputs, Gamma, Beta +__global__ void kernel_elementwise_layernorm( + const InGrid2dDescTuple in_grid_2d_desc_tuple, // Descriptor tuple of inputs + const GridDesc_M_K x_grid_desc_m_k, // Descriptor of X + const GridDesc_M_K gamma_grid_desc_m_k, // Descriptor of gamma + const GridDesc_M_K beta_grid_desc_m_k, // Descriptor of beta + const GridDesc_M_K y_grid_desc_m_k, // Descriptor of Y + index_t num_k_block_tile_iteration, // + AccDataType epsilon, // Datatype of epsilon + const InDataTypePointerTuple p_in_global_tuple, // Ptr tuple of input matrixs + const GammaDataType* const __restrict__ p_gamma_global, // Ptr of gamma + const BetaDataType* const __restrict__ p_beta_global, // Ptr of beta + YDataType* const __restrict__ p_y_global, // Ptr of y + const XElementwiseOperation x_elementwise_op, // Operation of input + const YElementwiseOperation y_elementwise_op) // Operation of output of normalization +{ + extern __shared__ XDataType p_x_lds[]; + GridwiseElementwiseReduction::Run(in_grid_2d_desc_tuple, // Descriptor tuple of inputs + x_grid_desc_m_k, // Descriptor of X + gamma_grid_desc_m_k, // Descriptor of Gamma + beta_grid_desc_m_k, // Descriptor of Beta + y_grid_desc_m_k, // Descriptor of Y + num_k_block_tile_iteration, // + epsilon, // epsilon + p_in_global_tuple, // Ptr tuple of inputs + p_x_lds, // Ptr of X + p_gamma_global, // Ptr of gamma + p_beta_global, // Ptr of beta + p_y_global, // Ptr of Y + x_elementwise_op, // Operation of input + y_elementwise_op); // Operation of output of normalization +}; +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Y = LayerNorm(A + B, Beta, Gamma) +template // Size to write destination Y +struct DeviceElementwiseNormalizationImpl + : public DeviceElementwiseNormalization +{ + static constexpr int NumInput = 
InDataTypeTuple::Size(); + + using XDataType = YDataType; + + static_assert( + (KThreadSliceSize % GammaSrcVectorSize == 0), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + (KThreadSliceSize % BetaSrcVectorSize == 0), + "Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); + + static constexpr index_t M_BlockTileSize = + MThreadClusterSize * MThreadSliceSize; // num of rows calculated in a block + static constexpr index_t K_BlockTileSize = + KThreadClusterSize * KThreadSliceSize; // num of columns calculated in a block + + static auto GenerateInDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(nullptr); + }, + Number{}); + }; + + using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple()); + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, + int blkGroupSize, + int numBlockTileIteration) + { + constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t numSrcDim = Rank; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); + const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + template + static auto GenerateSrcGrid2dDescTuple(Number) + { + return generate_tuple([&](auto) { return MakeSrc2dDescriptor({1}, {1}, 1, 1); }, + Number{}); + }; + + using InGrid2dDescTuple = decltype(GenerateSrcGrid2dDescTuple(Number{})); + + using 
GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); + + using GridwiseReduceLayernormGeneric = + GridwiseElementwiseLayernormWelfordVariance_mk_to_mk; + + using GridwiseReduceLayernormSweepOnce = + GridwiseElementwiseLayernormWelfordVariance_mk_to_mk; + + struct Argument : public BaseArgument + { + Argument(const std::vector lengths, + const std::array, NumInput> inStridesArray, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector reduceDims, + XElementwiseOperation x_elementwise_op, + YElementwiseOperation y_elementwise_op, + double epsilon, + const std::array in_dev_buffers, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + YDataType* p_y) + : p_gamma_(p_gamma), + p_beta_(p_beta), + p_y_(p_y), + x_elementwise_op_(x_elementwise_op), + y_elementwise_op_(y_elementwise_op) + { + epsilon_ = static_cast(epsilon); + + Lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); + for(int i = 0; i < NumInput; i++) + { + inStridesArray_[i] = + shuffle_tensor_dimensions(inStridesArray[i], reduceDims); + } + + yStrides_ = shuffle_tensor_dimensions(yStrides, reduceDims); + xStrides_ = shuffle_tensor_dimensions(yStrides, reduceDims); + + gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims); + betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims); + + in_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(in_dev_buffers[I.value]); + }, + Number{}); + + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(Lengths_); + + blkGroupSize_ = 1; + numBlockTileIteration_ = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + + gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize * blkGroupSize_; + + in_grid_2d_desc_tuple_ = generate_tuple( + [&](auto I) { + return MakeSrc2dDescriptor( + Lengths_, inStridesArray_[I.value], blkGroupSize_, numBlockTileIteration_); + }, + Number{}); + + x_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, xStrides_, blkGroupSize_, numBlockTileIteration_); + + gamma_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, gammaStrides_, blkGroupSize_, numBlockTileIteration_); + + beta_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, betaStrides_, blkGroupSize_, numBlockTileIteration_); + + y_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, yStrides_, blkGroupSize_, numBlockTileIteration_); + + sweep_once_ = + x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; + + if(!sweep_once_) // if not sweep once, compute memory size for matrix X in lds for + // store Intermediate results + { + int block_TileSize = M_BlockTileSize * reduce_total_length; + x_lds_size_ = block_TileSize * sizeof(XDataType); + } + else + x_lds_size_ = 0; + } + + AccDataType epsilon_; + + InDataTypePointerTuple in_dev_buffers_; + const GammaDataType* p_gamma_; + const BetaDataType* p_beta_; + YDataType* p_y_; + + std::vector Lengths_; + std::array, NumInput> inStridesArray_; + std::vector xStrides_; + std::vector gammaStrides_; + std::vector betaStrides_; + std::vector yStrides_; + + XElementwiseOperation x_elementwise_op_; + YElementwiseOperation y_elementwise_op_; + + int blkGroupSize_; + int numBlockTileIteration_; + size_t gridSize_; + + InGrid2dDescTuple in_grid_2d_desc_tuple_; + GridDesc_M_K x_grid_desc_m_k_; + GridDesc_M_K gamma_grid_desc_m_k_; + GridDesc_M_K 
beta_grid_desc_m_k_; + GridDesc_M_K y_grid_desc_m_k_; + bool sweep_once_; + int x_lds_size_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel_main = + arg.sweep_once_ ? kernel_elementwise_layernorm + : kernel_elementwise_layernorm; + + float avg_time = 0; + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize_), + dim3(BlockSize), + arg.x_lds_size_, + arg.in_grid_2d_desc_tuple_, + arg.x_grid_desc_m_k_, + arg.gamma_grid_desc_m_k_, + arg.beta_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + arg.numBlockTileIteration_, + arg.epsilon_, + arg.in_dev_buffers_, + arg.p_gamma_, + arg.p_beta_, + arg.p_y_, + arg.x_elementwise_op_, + arg.y_elementwise_op_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + if constexpr(XYSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return false; + } + else + { + for(int i = 0; i < NumInput; i++) + { + if(p_arg_->inStridesArray_[i][NumInvariantDim - 1] != 1) + return false; + } + + if(p_arg_->inStridesArray_[0][NumInvariantDim - 1] != 1 && + p_arg_->inStridesArray_[1][NumInvariantDim - 1] != 1) + return false; + + if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0) + return false; + }; + } + else + { + for(int i = 0; i < NumInput; i++) + { + if(p_arg_->inStridesArray_[i][Rank - 1] != 1) + return false; + } + + if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0) + return false; + }; + + if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0) + { + return false; + } + + auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) { + bool ret = true; + + if(!isLastDimensionCoalesced) + ret = scalarPerVector == 1; + else + ret = KThreadSliceSize % scalarPerVector == 0; + + return ret; + }; + + if(!IsScalarPerVectorValid(p_arg_->gammaStrides_.back() == 1, GammaSrcVectorSize)) + return false; + + if(!IsScalarPerVectorValid(p_arg_->betaStrides_.back() == 1, BetaSrcVectorSize)) + return false; + + // if fastest dim is not reduced + if constexpr(XYSrcVectorDim == 0) // + { + if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) + return (false); + } + else // if fastest dim is reduced + { + if(p_arg_->gammaStrides_[Rank - 1] != 1) + return (false); + + if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) + return (false); + } + + // if fastest dim is not reduced + if constexpr(XYSrcVectorDim == 0) + { + if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0) + return (false); + } + else // if fastest dim is reduced + { + if(p_arg_->betaStrides_[Rank - 1] != 1) + return (false); + + if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0) + return (false); + } + + if(p_arg_->x_lds_size_ >= 65536) + { + return (false); + } + + return true; + }; + + std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::array, NumInput> inStridesArray, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector reduceDims, + double epsilon, + const std::array 
in_dev_buffers, + const void* p_gamma, + const void* p_beta, + void* p_y, + XElementwiseOperation x_elementwise_op, + YElementwiseOperation y_elementwise_op) override + { + return std::make_unique(lengths, + inStridesArray, + gammaStrides, + betaStrides, + yStrides, + reduceDims, + x_elementwise_op, + y_elementwise_op, + epsilon, + in_dev_buffers, + static_cast(p_gamma), + static_cast(p_beta), + static_cast(p_y)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceElementwiseNormalizationImpl<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "XYSrcVectorDim_" << XYSrcVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp new file mode 100755 index 000000000000..dff0530ee026 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp @@ -0,0 +1,346 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/math.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/tensor_operation/gpu/device/device_elementwise_scale.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/** + * \note This structure is deprecated (left for backwards compatibility). Please use + * DeviceElementwiseImpl from device_elementwise_dynamic_vector_dims_impl.hpp. 
+ */ +template +struct DeviceElementwiseImpl : public DeviceElementwise +{ + static constexpr int NumInput = InDataTypeTuple::Size(); + static constexpr int NumOutput = OutDataTypeTuple::Size(); + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static auto GenerateInDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + static auto GenerateOutDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple()); + using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple()); + + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) + { + constexpr auto I0 = Number<0>{}; + + const auto m = desc_m.GetLength(I0); + const index_t loop_step = gridSize * blockSize * MPerThread; + const auto pad = math::integer_least_multiple(m, loop_step) - m; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(m, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(const std::array& lengths, + const std::array& stride, + index_t gridSize, + index_t blockSize) + { + auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number{}); + auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number{}); + + // nd desc - [s0, s1, s2, ...] + const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride); + + // merge nd to 1d desc - [s0 * s1 * ...] 
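+        // e.g. lengths {N, C, H, W} collapse into a single axis of length N*C*H*W, which
+        // PadDescriptor_M_1d then right-pads to a multiple of gridSize * blockSize * MPerThread.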
+ if constexpr(NumDim > 1) + { + const auto desc_m = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(tupleOfShape)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{})), + make_tuple(Sequence<0>{})); + + return PadDescriptor_M_1d(desc_m, gridSize, blockSize); + } + else + return PadDescriptor_M_1d(desc, gridSize, blockSize); + } + + template + static auto GenerateInOutGrid1dDescTuple(Number) + { + return generate_tuple( + [&](auto) { + if constexpr(NumDim > 1) + { + return MakeDescriptor_M({1, 1}, {1, 1}, 1, 1); + } + else + { + return MakeDescriptor_M({1}, {1}, 1, 1); + }; + }, + Number{}); + }; + + using InGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number{})); + using OutGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number{})); + + using GridwiseElementwise = GridwiseElementwise_1D; + + struct Argument : public BaseArgument + { + Argument(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op, + UnaryOperation unary_op, + Scale scale_op) + + : lengths_(lengths), + inStridesArray_(inStridesArray), + outStridesArray_(outStridesArray), + elementwise_op_(elementwise_op), + unary_op_(unary_op), + scale_op_(scale_op), + blockSize_(256) + { + in_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(in_dev_buffers[I.value]); + }, + Number{}); + + out_dev_buffers_ = generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + return static_cast(out_dev_buffers[I.value]); + }, + Number{}); + } + + InDataTypePointerTuple in_dev_buffers_; + OutDataTypePointerTuple out_dev_buffers_; + + std::array lengths_; + std::array, NumInput> inStridesArray_; + std::array, NumOutput> outStridesArray_; + + ElementwiseOperation elementwise_op_; + UnaryOperation unary_op_; + Scale scale_op_; + index_t blockSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + index_t gridSize = getAvailableComputeUnitCount(stream_config); + + auto in_grid_1d_desc_tuple = generate_tuple( + [&](auto I) { + return MakeDescriptor_M( + arg.lengths_, arg.inStridesArray_[I.value], gridSize, arg.blockSize_); + }, + Number{}); + + auto out_grid_1d_desc_tuple = generate_tuple( + [&](auto I) { + return MakeDescriptor_M( + arg.lengths_, arg.outStridesArray_[I.value], gridSize, arg.blockSize_); + }, + Number{}); + + const auto kernel = kernel_elementwise_1d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + in_grid_1d_desc_tuple, + out_grid_1d_desc_tuple, + arg.in_dev_buffers_, + arg.out_dev_buffers_, + arg.elementwise_op_, + arg.unary_op_, + arg.scale_op_); + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(arg.lengths_.back() % MPerThread != 0) + return false; + + auto IsScalarPerVectorValid = [&](const std::array& lengths, + const std::array& strides, + index_t scalarPerVector) { + if(strides.back() == 1 && lengths.back() % scalarPerVector == 0) + return true; + + if(strides.back() != 1 && scalarPerVector == 1) + return true; + + return false; + }; + + bool valid = 
true; + static_for<0, NumInput, 1>{}([&](auto I) { + if(!IsScalarPerVectorValid( + arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I))) + valid = false; + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + if(!IsScalarPerVectorValid( + arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I))) + valid = false; + }); + + return valid; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op, + UnaryOperation unary_op, + Scale scale_op) + { + return Argument{lengths, + inStridesArray, + outStridesArray, + in_dev_buffers, + out_dev_buffers, + elementwise_op, + unary_op, + scale_op}; + } + + std::unique_ptr + MakeArgumentPointer(const std::array lengths, + const std::array, NumInput> inStridesArray, + const std::array, NumOutput> outStridesArray, + const std::array in_dev_buffers, + const std::array out_dev_buffers, + ElementwiseOperation elementwise_op, + UnaryOperation unary_op, + Scale scale_op) override + { + return std::make_unique(lengths, + inStridesArray, + outStridesArray, + in_dev_buffers, + out_dev_buffers, + elementwise_op, + unary_op, + scale_op); + } + + static auto MakeInvoker() { return Invoker{}; } + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceElementwiseNormalizationImpl<"; + str << NumDim << ", "; + str << MPerThread << ">"; + // clang-format on + + return str.str(); + } +}; // namespace device + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp new file mode 100755 index 000000000000..553143e2861e --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_fpAintB_gemm_wmma.hpp @@ -0,0 +1,714 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_dequantB.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// 1. DequantB(K, N) = int2fp(B(K, N)) * scale(1, N) +// 2. 
C(M, N) = A(M, K) * DequantB(K, N) + +template +struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + // K1 = Max Vector Access Pixels + static constexpr auto K1Number = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + + static constexpr auto AEnableLds_auto = + (NWaves == 1 && is_same::value) ? false : true; + static constexpr auto BEnableLds_auto = + (MWaves == 1 && is_same::value) ? false : true; + + // If true, LDS is used unconditionally + // LDS bypass feature not implemented for dequantization pipeline. + static constexpr auto AEnableLds_manu = true; + static constexpr auto BEnableLds_manu = true; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + using DeviceOp = DeviceFpAintBGemm_Wmma_CShuffle; + + // Describe how data read from Global memory + static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else if constexpr(is_same::value) + { + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + }(); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(AEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_n_k = [&]() { + if constexpr(is_same::value) + { + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else if constexpr(is_same_v) + { + const auto b_grid_desc_nraw_kraw = + 
make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + }(); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + static auto MakeScaleGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB = 0) + { + // assume Scale is [1, N] + const auto scale_grid_desc_n_k = [&]() { + const auto scale_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); + + return matrix_padder.PadBDescriptor_N_K(scale_grid_desc_nraw_kraw); + }(); + + const auto N = scale_grid_desc_n_k.GetLength(I0); + const auto K = scale_grid_desc_n_k.GetLength(I1); + // When K = 1, it might be scale tensor. + assert(K % K1 == 0 && K != 1); + + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + scale_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, 1)), // Reduce K1 = 1 + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + scale_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); + } + + // Gridwise descriptor, mapping to whole given provblem. 
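+// The using-declarations below invoke the Make*GridDescriptor helpers with dummy extents
+// (1, 1, 1) purely to obtain their return types through decltype; the real descriptors are
+// rebuilt at runtime from the actual M/N/K extents and strides inside Argument's constructor.
+// A minimal, hedged illustration of the same idiom (names here are for exposition only):
+//
+//   auto make_desc = [](index_t m, index_t k) {
+//       return make_naive_tensor_descriptor(make_tuple(m, k), make_tuple(k, I1));
+//   };
+//   using DescType = decltype(make_desc(1, 1)); // type only; the (1, 1) values are discarded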
+ using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1)); + using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1)); + using ScaleGridDesc = decltype(MakeScaleGridDescriptor(1, 1, 0)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseFpAintBGemm_Wmma< + BlockSize, + ADataType, + BDataType, + ScaleDataType, + AccDataType, + CShuffleDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc, + BGridDesc, + ScaleGridDesc, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + K1, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + AEnableLds, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BEnableLds, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + NumPrefetch, + LoopSched, + PipelineVer>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + const ScaleDataType* p_scale_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_scale_grid_{p_scale_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_{}, + b_grid_desc_{}, + scale_grid_desc_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + MRaw_{M}, + NRaw_{N}, + KRaw_{K} + { + a_grid_desc_ = DeviceOp::MakeAGridDescriptor(M, K, StrideA); + b_grid_desc_ = DeviceOp::MakeBGridDescriptor(K, N, StrideB); + scale_grid_desc_ = DeviceOp::MakeScaleGridDescriptor(K, N, 0); + c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(M, N, StrideC); + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_, b_grid_desc_, c_grid_desc_m_n_, block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + const ScaleDataType* p_scale_grid_; + CDataType* p_c_grid_; + AGridDesc a_grid_desc_; + BGridDesc b_grid_desc_; + ScaleGridDesc scale_grid_desc_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation 
b_element_op_; + CElementwiseOperation c_element_op_; + // for checking vector load/store + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, + arg.b_grid_desc_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = [&]() { + if constexpr(AEnableLds) + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2); + } + else + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) * + arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6); + } + }(); + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_fpAintB_gemm_wmma< + GridwiseGemm, + ADataType, + BDataType, + ScaleDataType, + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + has_main_k_block_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_scale_grid_, + arg.p_c_grid_, + arg.a_grid_desc_, + arg.b_grid_desc_, + arg.scale_grid_desc_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v || + is_same_v)) + { + printf("DeviceOp err: AccDataType"); + return false; + } + } + else + { + printf("DeviceOp err: Arch"); + return false; + } + + // check vector load/store + { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // check vector load of A + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector store of C + // only 
support RowMajor for now + if constexpr(is_same_v) + { + if(arg.NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_, arg.b_grid_desc_, arg.c_grid_desc_m_n_, arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + const ScaleDataType* p_scale, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_scale, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_scale, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_scale), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{ + {PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}, + {PipelineVersion::weight_only, "weight_only"}}; + + // clang-format off + str << "DeviceFpAintBGemm_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << K1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp new file mode 100755 index 000000000000..5bfef60af285 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_bias_add_reduce_xdl_cshuffle.hpp @@ -0,0 +1,888 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
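+//
+// Usage sketch (illustrative only): a concrete instance of the device op defined in this
+// header is normally driven through the polymorphic DeviceGemmReduce interface. The
+// `DeviceOpInstance` alias and the pointer/operator names below are hypothetical and are
+// supplied by the caller, not by this header.
+//
+//   auto op  = DeviceOpInstance{}; // DeviceGemmBiasAddReduce_Xdl_CShuffle<...>
+//   auto arg = op.MakeArgumentPointer(p_a, p_b, p_bias, {p_d0}, p_c, {p_reduce},
+//                                     M, N, K, StrideA, StrideB, StrideC, {StrideD0},
+//                                     gemm_element_ops, d_element_ops,
+//                                     reduce_in_element_ops, reduce_out_element_ops);
+//   if(op.IsSupportedArgument(arg.get()))
+//       op.MakeInvokerPointer()->Run(arg.get(), StreamConfig{});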
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. +template +struct DeviceGemmBiasAddReduce_Xdl_CShuffle : public DeviceGemmReduce<1, ReduceOperations::Size()> +{ + using DeviceOp = DeviceGemmBiasAddReduce_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + 
transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + 
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + // assume D is packed tensor + static auto MakeReduceGridDescriptor_M(index_t MRaw) + { + const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor(d_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return d_grid_desc_mraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using C0GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 0)); + using C1GridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + BiasDataType, + D0DataType, + ReduceAccDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + 
CElementwiseOperation, + D0ElementwiseOperation, + ReduceOperations, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + InMemoryDataOperationEnum::Set, + ReduceGlobalMemoryDataOperation, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + C0GridDesc_M_N, + C1GridDesc_M_N, + ReduceGridDesc_M, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BiasDataType* p_bias_grid, + const D0DataType* p_d0_grid, + ReducePtrsGlobal p_reduces_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideC1, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + D0ElementwiseOperation d0_element_op, + ReduceInElementwiseOperations reduce_in_element_ops, + ReduceAccElementwiseOperations reduce_out_element_ops) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_bias_grid_{p_bias_grid}, + p_d0_grid_{p_d0_grid}, + p_reduces_grid_{p_reduces_grid}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + c0_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, 0)}, + c1_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC1)}, + reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + c0_grid_desc_mblock_mperblock_nblock_nperblock_{}, + c1_grid_desc_mblock_mperblock_nblock_nperblock_{}, + reduce_grid_desc_mblock_mperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + d0_element_op_{d0_element_op}, + reduce_in_element_ops_{reduce_in_element_ops}, + reduce_out_element_ops_{reduce_out_element_ops} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + c0_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + 
c0_grid_desc_m_n_); + + c1_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c1_grid_desc_m_n_); + + reduce_grid_desc_mblock_mperblock_ = + GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const BiasDataType* p_bias_grid_; + const D0DataType* p_d0_grid_; + ReducePtrsGlobal p_reduces_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_M_N c0_grid_desc_m_n_; + C1GridDesc_M_N c1_grid_desc_m_n_; + ReduceGridDesc_M reduce_grid_desc_m_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c0_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c1_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock + reduce_grid_desc_mblock_mperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + D0ElementwiseOperation d0_element_op_; + ReduceInElementwiseOperations reduce_in_element_ops_; + ReduceAccElementwiseOperations reduce_out_element_ops_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float elapsed_time = 0.0f; + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_bias_add_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + BiasDataType, + D0DataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + D0ElementwiseOperation, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + true>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_bias_grid_, + arg.p_d0_grid_, + arg.p_reduces_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.d0_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_bias_add_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + BiasDataType, + D0DataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + D0ElementwiseOperation, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + false>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_bias_grid_, + arg.p_d0_grid_, + arg.p_reduces_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.d0_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c0_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c1_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); + } + + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return 
true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static constexpr int NumReduce = ReduceOperations::Size(); + static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op) + { + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + D0ElementwiseOperation d_element_op = + *(static_cast(d_element_ops[0])); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_bias), + static_cast(p_ds[0]), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideDs[0], + a_element_op, + b_element_op, + c_element_op, + d_element_op, + reduce_in_element_ops, + reduce_out_element_ops}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + index_t /* KBatch */ = 1) override + { + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + 
*(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + D0ElementwiseOperation d_element_op = + *(static_cast(d_element_ops[0])); + + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_bias), + static_cast(p_ds[0]), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideDs[0], + a_element_op, + b_element_op, + c_element_op, + d_element_op, + reduce_in_element_ops, + reduce_out_element_ops); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmBiasAddReduce_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp new file mode 100755 index 000000000000..7faee161c11f --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp @@ -0,0 +1,646 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
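+//
+// Usage sketch (illustrative only; `DeviceGemmDlInstance` stands for a fully specialized
+// DeviceGemmDl<...> chosen by the caller). The DL path is accepted on gfx906, gfx103x,
+// gfx11 and gfx12 devices, as checked by IsSupportedArgument() further down in this header.
+//
+//   auto op       = DeviceGemmDlInstance{};
+//   auto argument = DeviceGemmDlInstance::MakeArgument(p_a, p_b, p_c, M, N, K,
+//                                                      StrideA, StrideB, StrideC,
+//                                                      a_element_op, b_element_op, c_element_op);
+//   if(op.IsSupportedArgument(argument))
+//   {
+//       auto invoker = DeviceGemmDlInstance::MakeInvoker();
+//       invoker.Run(argument, StreamConfig{});
+//   }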
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template < + typename ADataType, + typename BDataType, + typename CDataType, + typename AccDataType, + typename ALayout, + typename BLayout, + typename CLayout, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + GemmSpecialization GemmSpec, + index_t BlockSize, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t K1, + index_t M1PerThread, + index_t N1PerThread, + index_t KPerThread, + typename M1N1ThreadClusterM1Xs, + typename M1N1ThreadClusterN1Xs, + typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + typename ABlockTransferSrcVectorTensorContiguousDimOrder, + typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + typename BBlockTransferSrcVectorTensorContiguousDimOrder, + typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + typename CThreadTransferSrcDstAccessOrder, + index_t CThreadTransferSrcDstVectorDim, + index_t CThreadTransferDstScalarPerVector, + enable_if_t< + is_same_v && + is_same_v && + is_same_v, + bool> = false> +struct DeviceGemmDl : public DeviceGemm + +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + 
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDl_km_kn_mn_v1r3; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using DefaultBlock2CTileMap = + decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_k0_m0_m1_k1_{}, + b_grid_desc_k0_n0_n1_k1_{}, + c_grid_desc_m0_m10_m11_n0_n10_n11_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + 
M_raw_{M}, + N_raw_{N}, + K_raw_{K}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + a_grid_desc_k0_m_k1_ = DeviceGemmDl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = DeviceGemmDl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = DeviceGemmDl::MakeCGridDescriptor_M_N(M, N, StrideC); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_)) + { + a_grid_desc_k0_m0_m1_k1_ = + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_); + b_grid_desc_k0_n0_n1_k1_ = + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_); + c_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_; + BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_; + CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_; + + DefaultBlock2CTileMap block_2_ctile_map_; + + // TODO: unused, but may be useful in future. + index_t M01_; + index_t N01_; + + index_t M_raw_; + index_t N_raw_; + index_t K_raw_; + + // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being. + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmDl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{" + << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{" + << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_k0mk1_k0nk1_mn_xdl_v2r3 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize( + arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1)); + + const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + float ave_time = 0; + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + true, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = + kernel_gemm_dl_v1r3, + remove_reference_t, + remove_reference_t, + remove_reference_t, + false, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // Make sure that the M, N, K dimensions before padding are divisible by respective vector + // lengths. 
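+        // Hedged worked example of what this guard enforces: for a row-major A the source
+        // vector is read along K, so the raw K extent must be divisible by the product of the
+        // K components of ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 (At(I0) * At(I3));
+        // e.g. lengths <2, 1, 1, 2> require K % 4 == 0. For a column-major A the product of
+        // the M components (At(I1) * At(I2)) must divide the raw M extent instead.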
+ if constexpr(is_same::value) + { + constexpr auto A_K_vec_length = + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I0) * + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I3); + if(arg.K_raw_ % A_K_vec_length != 0) + { + return false; + } + } + else + { + constexpr auto A_M_vec_lenght = + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I1) * + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I2); + if(arg.M_raw_ % A_M_vec_lenght != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + constexpr auto B_N_vec_lenght = + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I1) * + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I2); + if(arg.N_raw_ % B_N_vec_lenght != 0) + { + return false; + } + } + else + { + constexpr auto B_K_vec_length = + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I0) * + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I3); + if(arg.K_raw_ % B_K_vec_length != 0) + { + return false; + } + } + + if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() || + ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_); + } + return false; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + virtual std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmDl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << M1PerThread << ", " + << N1PerThread << ", " + << KPerThread + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp new file mode 100755 index 000000000000..168b9ad81210 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp @@ -0,0 +1,271 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
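+//
+// Usage sketch (illustrative only; `DeviceGemmDppInstance` is a placeholder for a fully
+// specialized DeviceGemmDpp<...>, and the element-wise operators are caller-provided).
+// The DPP path is only accepted on gfx103x / gfx11 devices, as enforced by
+// IsSupportedArgument() below.
+//
+//   auto op       = DeviceGemmDppInstance{};
+//   auto argument = DeviceGemmDppInstance::MakeArgument(p_a, p_b, p_c, M, N, K,
+//                                                       StrideA, StrideB, StrideC,
+//                                                       a_element_op, b_element_op, c_element_op);
+//   if(op.IsSupportedArgument(argument))
+//       DeviceGemmDppInstance::MakeInvoker().Run(argument, StreamConfig{});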
+ +#pragma once + +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmDpp : public DeviceGemm +{ + using GridwiseGemm = GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp< + BlockSize, + ADataType, + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + MPerBlock, + NPerBlock, + KPerBlock, + MPerDpp, + NPerDpp, + AK1, + BK1, + MDppPerWave, + NDppPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 1, 3, 5>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumPrefetch, + PipelineVer>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + karg.Print(); + } + + if(!GridwiseGemm::CheckValidity(karg)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_k0mk1_k0nk1_mn_dpp has invalid setting"); + } + + const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K)) + { + const auto kernel = kernel_gemm_dpp; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); + } + else + { + const auto kernel = kernel_gemm_dpp; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& karg) + { + if(ck::is_gfx103_supported() || ck::is_gfx11_supported()) + { + return GridwiseGemm::CheckValidity(karg); + } + return false; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmDpp" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerDpp << ", " + << NPerDpp << ", " + << MDppPerWave << ", " + << MDppPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 + << ">" + << " NumPrefetch: " + << NumPrefetch << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp new file mode 100755 index 000000000000..4d9dd192c986 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,751 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 
2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_abd.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultipleABD_Xdl_CShuffle : public DeviceGemmMultipleABD +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + AsDataType, + BsDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 
1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { +#if 0 + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else +#endif + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { +#if 0 + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else +#endif + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + 
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { +#if 0 + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else +#endif + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + } + else + { +#if 0 + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + else +#endif + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { +#if 0 + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else +#endif + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return 
GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ALayout_ = remove_cvref_t>; + + static_assert(is_same::value, ""); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BLayout_ = remove_cvref_t>; + + static_assert(is_same::value, ""); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout_ = remove_cvref_t>; + + static_assert(is_same::value, ""); + }); + + return Argument{p_as, + p_bs, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideAs, + StrideBs, + StrideDs, + StrideE, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(std::array p_as, + std::array p_bs, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + std::array StrideAs, + std::array StrideBs, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(p_as, + p_bs, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideAs, + StrideBs, + StrideDs, + StrideE, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dl_multiple_d( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + 
const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, + const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx9__) || \ + defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__)) + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType); + + __shared__ ABDataType p_shared[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m0_m1_k1, + b_grid_desc_k0_n0_n1_k1, + ds_grid_desc_m0_m10_m11_n0_n10_n11, + e_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_ctile_map, + integral_constant{}, + integral_constant{}); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_k0_m0_m1_k1; + ignore = b_grid_desc_k0_n0_n1_k1; + ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = e_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = block_2_ctile_map; +#endif +} +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +template && + is_same_v, + bool> = false> +struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD + +{ + using DeviceOp = DeviceGemmMultipleD_Dl; + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if 
constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + template + static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {})); + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDlMultipleD_km_kn_mn; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using DsGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{})); + using EGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{})); + using DefaultBlock2CTileMap = + decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_k0_m0_m1_k1_{}, + b_grid_desc_k0_n0_n1_k1_{}, + 
e_grid_desc_m0_m10_m11_n0_n10_n11_{}, + block_2_ctile_map_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + a_grid_desc_k0_m_k1_ = + DeviceGemmMultipleD_Dl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = + DeviceGemmMultipleD_Dl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(M, N, StrideDs[i]); + }); + e_grid_desc_m_n_ = + DeviceGemmMultipleD_Dl::MakeEGridDescriptor_M_N(M, N, StrideE); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, e_grid_desc_m_n_)) + { + a_grid_desc_k0_m0_m1_k1_ = + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1_); + b_grid_desc_k0_n0_n1_k1_ = + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1_); + + ds_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n_); + + e_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(e_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_; + BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_; + DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_; + EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_; + + DefaultBlock2CTileMap block_2_ctile_map_; + + // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being. + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmMultipleD_Dl::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + { + std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{" + << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{" + << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", " + << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemmDlMultipleD_km_kn_mn has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize( + arg.e_grid_desc_m_n_.GetLength(I0), arg.e_grid_desc_m_n_.GetLength(I1)); + + auto launch_kernel = [&](auto has_main_k_block_loop, + auto has_double_tail_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr bool has_double_loop = has_double_tail_k_block_loop.value; + + const auto kernel = + kernel_gemm_dl_multiple_d; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.ds_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.e_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_); + }; + + const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || + ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_); + } + else + { + return false; + } + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + 
cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmMultipleD_Dl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << M1PerThread << ", " + << N1PerThread << ", " + << KPerThread + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp new file mode 100755 index 000000000000..47fb630ea903 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp @@ -0,0 +1,1088 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EMeanVarDataType* __restrict__ p_e_grid, + EMeanVarDataType* __restrict__ p_welford_mean_grid, + EMeanVarDataType* __restrict__ p_welford_var_grid, + int32_t* __restrict__ p_welford_count_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock + mean_var_grid_desc_mblock_mperblock_nblock, + const CountGridDescriptor_MBlock_MPerBlock_NBlock + count_grid_desc_mblock_mperblock_nblock, + const Block2ETileMap block_2_etile_map, + index_t NRaw) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()]; + + GridwiseGemmWelford::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_welford_mean_grid, + p_welford_var_grid, + p_welford_count_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + 
a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + mean_var_grid_desc_mblock_mperblock_nblock, + count_grid_desc_mblock_mperblock_nblock, + block_2_etile_map, + NRaw); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = p_welford_mean_grid; + ignore = p_welford_var_grid; + ignore = p_welford_count_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = mean_var_grid_desc_mblock_mperblock_nblock; + ignore = count_grid_desc_mblock_mperblock_nblock; + ignore = block_2_etile_map; + ignore = NRaw; +#endif +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_welford_layernorm2d_second_half( + const EMeanVarDataType* __restrict__ p_e_grid, + const EMeanVarDataType* __restrict__ p_in_welford_mean_grid, + const EMeanVarDataType* __restrict__ p_in_welford_var_grid, + const int32_t* __restrict__ p_in_welford_count_grid, + const GammaDataType* __restrict__ p_gamma_grid, + const BetaDataType* __restrict__ p_beta_grid, + HDataType* __restrict__ p_h_grid, + const EHGridDesc_M_N e_grid_desc_m_n, + const EHGridDesc_M_N h_grid_desc_m_n, + const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock, + const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock, + const GammaBetaGridDesc_N gamma_grid_desc_n, + const GammaBetaGridDesc_N beta_grid_desc_n, + index_t numMeanVarCountBlockTileIteration_N, + index_t NBlockClusterLength, + ComputeDataType epsilon, + HElementwiseOperation h_element_op) +{ + GridwiseWelfordLayernorm::Run(p_e_grid, + p_in_welford_mean_grid, + p_in_welford_var_grid, + p_in_welford_count_grid, + p_gamma_grid, + p_beta_grid, + p_h_grid, + e_grid_desc_m_n, + h_grid_desc_m_n, + mean_var_grid_desc_m_nblock, + count_grid_desc_m_nblock, + gamma_grid_desc_n, + beta_grid_desc_n, + numMeanVarCountBlockTileIteration_N, + NBlockClusterLength, + epsilon, + h_element_op); +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// output : H[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// H = layernorm(E) +// Assume: +// D0, D1, ... and E have the same layout +// Calculate mean & variance along N dimension in layernorm(E) +template +struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle + : public DeviceGemmMultipleDLayernorm +{ + // EDataType, MeanDataType and VarDataType must be the same. + // eg. 
M, N, K = [1, 1, 1], + // in case of layernorm, divisor = 1 / sqrt(var + 1e-5) = 316.227783 + // if (x - mean) != 0, (x - mean) * divisor * gamma might be too large + // However, (x - mean) * divisor * gamma should be 0 in this case + + using DeviceOp = DeviceGemmMultipleDLayernorm_Xdl_CShuffle; + using ELayout = HLayout; + + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t LayernormHDstVectorSize = PostShuffleScalarPerVector; + static constexpr index_t LayernormGammaSrcVectorSize = PostShuffleScalarPerVector; + static constexpr index_t LayernormBetaSrcVectorSize = PostShuffleScalarPerVector; + static constexpr index_t LayernormESrcVectorSize = PostShuffleScalarPerVector; + static constexpr index_t LayernormThreadSliceSize_N = PostShuffleScalarPerVector; + using LayernormBlockTileSize_M_N = + Sequence; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto matrix_padder = MatrixPadder{ + GemmMPerBlock, GemmNPerBlock, GemmKPerBlock}; + + static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + static auto MakeEHGridDescriptor_M_N(index_t M, index_t N, index_t Stride) + { + // Only support row major for E and H + const auto grid_desc_m_n = + make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(Stride, I1)); + return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{}); + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + static_assert(is_same::value); + + return DeviceOp:: + MakeEHGridDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>( + MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + template + static auto MakeMeanVarDescriptor_M_N(index_t M, index_t N) + { + const auto grid_desc_m_n = + make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); + return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{}); + } + + template + static auto MakeCountDescriptor_M_N(index_t M, index_t N) + { + // We will broadcast [N] to [M, N] in this descriptor + // Hence, 1st stride is 0 + const auto grid_desc_m_n = + make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, I1)); + return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{}); + } + + template + static auto MakeDescriptor_X(index_t X) + { + const auto grid_desc_x = make_naive_tensor_descriptor_packed(make_tuple(X)); + return PadTensorDescriptor(grid_desc_x, make_tuple(XPerTile), 
Sequence{}); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); + using DsGridDesc_M_N = remove_cvref_t; + // We have to separate mean var descriptor for gemm and layernorm bacause of different grid + // layout(different padding) + using GemmMeanVarGridDesc_M_NBlock = + decltype(MakeMeanVarDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>(1, + 1)); + + using GemmCountGridDesc_M_NBlock = + decltype(MakeCountDescriptor_M_N, GemmMPerBlock, GemmNPerBlock>(1, + 1)); + + using LayernormMeanVarGridDesc_M_NBlock = + decltype(MakeMeanVarDescriptor_M_N, + LayernormBlockTileSize_M_N::At(0), + LayernormBlockTileSize_M_N::At(1)>(1, 1)); + + using LayernormCountGridDesc_M_NBlock = + decltype(MakeCountDescriptor_M_N, + LayernormBlockTileSize_M_N::At(0), + LayernormBlockTileSize_M_N::At(1)>(1, 1)); + + using GammaBetaGridDesc_N = decltype(MakeDescriptor_X(1)); + using EHGridDesc_M_N = decltype(MakeEHGridDescriptor_M_N, 1, 1>(1, 1, 1)); + + using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + DsDataType, + EMeanVarDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_M_K, + BGridDesc_N_K, + DsGridDesc_M_N, + EHGridDesc_M_N, + GemmMeanVarGridDesc_M_NBlock, + GemmCountGridDesc_M_NBlock, + NumGemmKPrefetchStage, + BlockSize, + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + PostShuffleThreadClusterSize_M_N, + PostShuffleScalarPerVector, + LoopSched, + PipelineVer>; + + using Block2ETileMap = typename GridwiseGemmWelford::DefaultBlock2ETileMap; + + using GridwiseWelfordLayernorm = + GridwiseWelfordSecondHalfLayernorm2d; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + const void* p_gamma_grid, + const void* p_beta_grid, + void* p_h_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideH, + double epsilon, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + HElementwiseOperation h_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_workspace_e_grid_{nullptr}, + p_workspace_mean_{nullptr}, + p_workspace_var_{nullptr}, + p_workspace_count_{nullptr}, + p_gamma_grid_{static_cast(p_gamma_grid)}, + p_beta_grid_{static_cast(p_beta_grid)}, + p_h_grid_{static_cast(p_h_grid)}, + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, + ds_grid_desc_m_n_{}, + gemm_e_grid_desc_m_n_{ + DeviceOp::MakeEHGridDescriptor_M_N, + 
GemmMPerBlock, + GemmNPerBlock>(MRaw, NRaw, StrideH)}, + layernorm_e_grid_desc_m_n_{ + DeviceOp::MakeEHGridDescriptor_M_N, + LayernormBlockTileSize_M_N::At(0), + LayernormBlockTileSize_M_N::At(1)>( + MRaw, NRaw, StrideH)}, + gemm_mean_var_grid_desc_m_nblock_{}, + gemm_count_grid_desc_m_nblock_{}, + layernorm_mean_var_grid_desc_m_nblock_{}, + layernorm_count_grid_desc_m_nblock_{}, + gamma_grid_desc_n_{ + DeviceOp::MakeDescriptor_X(NRaw)}, + beta_grid_desc_n_{ + DeviceOp::MakeDescriptor_X(NRaw)}, + h_grid_desc_m_n_{ + DeviceOp::MakeEHGridDescriptor_M_N, + LayernormBlockTileSize_M_N::At(0), + LayernormBlockTileSize_M_N::At(1)>( + MRaw, NRaw, StrideH)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemmWelford::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemmWelford::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + block_2_etile_map_{ + GridwiseGemmWelford::MakeDefaultBlock2ETileMap(gemm_e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + h_element_op_{h_element_op}, + MRaw_{MRaw}, + NRaw_{NRaw}, + KRaw_{KRaw}, + gemm_nblock_{math::integer_divide_ceil(NRaw, GemmNPerBlock)}, + epsilon_{static_cast(epsilon)} + { + // We don't need to pad in N dimension in gemm for mean/var/count. Set NPerTile 1. + gemm_mean_var_grid_desc_m_nblock_ = + DeviceOp::MakeMeanVarDescriptor_M_N, GemmMPerBlock, 1>( + MRaw, gemm_nblock_); + + gemm_count_grid_desc_m_nblock_ = + DeviceOp::MakeCountDescriptor_M_N, GemmMPerBlock, 1>( + MRaw, gemm_nblock_); + + layernorm_mean_var_grid_desc_m_nblock_ = + DeviceOp::MakeMeanVarDescriptor_M_N, + LayernormBlockTileSize_M_N::At(0), + LayernormBlockTileSize_M_N::At(1)>( + MRaw, gemm_nblock_); + + layernorm_count_grid_desc_m_nblock_ = + DeviceOp::MakeCountDescriptor_M_N, + LayernormBlockTileSize_M_N::At(0), + LayernormBlockTileSize_M_N::At(1)>(MRaw, + gemm_nblock_); + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEHGridDescriptor_M_N, + GemmMPerBlock, + GemmNPerBlock>(MRaw, NRaw, StrideDs[i]); + }); + + // populate desc for Ds/E/mean/var/count + if(GridwiseGemmWelford::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + gemm_e_grid_desc_m_n_, + block_2_etile_map_)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemmWelford::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemmWelford::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + gemm_e_grid_desc_m_n_); + + gemm_mean_var_grid_desc_mblock_mperblock_nblock_ = + GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( + gemm_mean_var_grid_desc_m_nblock_); + + gemm_count_grid_desc_mblock_mperblock_nblock_ = + GridwiseGemmWelford::MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock( + gemm_count_grid_desc_m_nblock_); + } + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << gemm_e_grid_desc_m_n_ << std::endl; + std::cout << "H[M, N]: " << h_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; 
+ const BDataType* p_b_grid_; + typename GridwiseGemmWelford::DsGridPointer p_ds_grid_; + void* p_workspace_e_grid_; + void* p_workspace_mean_; + void* p_workspace_var_; + void* p_workspace_count_; + const GammaDataType* p_gamma_grid_; + const BetaDataType* p_beta_grid_; + HDataType* p_h_grid_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EHGridDesc_M_N gemm_e_grid_desc_m_n_; + EHGridDesc_M_N layernorm_e_grid_desc_m_n_; + GemmMeanVarGridDesc_M_NBlock gemm_mean_var_grid_desc_m_nblock_; + GemmCountGridDesc_M_NBlock gemm_count_grid_desc_m_nblock_; + LayernormMeanVarGridDesc_M_NBlock layernorm_mean_var_grid_desc_m_nblock_; + LayernormCountGridDesc_M_NBlock layernorm_count_grid_desc_m_nblock_; + GammaBetaGridDesc_N gamma_grid_desc_n_; + GammaBetaGridDesc_N beta_grid_desc_n_; + EHGridDesc_M_N h_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + typename GridwiseGemmWelford::DefaultBGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + typename GridwiseGemmWelford::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemmWelford::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemmWelford::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock + gemm_mean_var_grid_desc_mblock_mperblock_nblock_; + typename GridwiseGemmWelford::CountGridDescriptor_MBlock_MPerBlock_NBlock + gemm_count_grid_desc_mblock_mperblock_nblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + HElementwiseOperation h_element_op_; + + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + index_t gemm_nblock_; + AccDataType epsilon_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0; + + if(!GridwiseGemmWelford::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.gemm_e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting"); + } + if(arg.p_workspace_e_grid_ == nullptr || arg.p_workspace_mean_ == nullptr || + arg.p_workspace_var_ == nullptr || arg.p_workspace_count_ == nullptr) + throw std::runtime_error("wrong! 
WorkSpace pointer has not been set"); + + index_t grid_size = arg.block_2_etile_map_.CalculateGridSize(arg.gemm_e_grid_desc_m_n_); + + const auto M = arg.h_grid_desc_m_n_.GetLength(I0); + const auto N = arg.h_grid_desc_m_n_.GetLength(I1); + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel_gemm_welford = + kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle< + GridwiseGemmWelford, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemmWelford::DsGridPointer, + EMeanVarDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1, + typename GridwiseGemmWelford::DefaultBGridDesc_BK0_N_BK1, + typename GridwiseGemmWelford:: + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemmWelford:: + EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemmWelford::MeanVarGridDescriptor_MBlock_MPerBlock_NBlock, + typename GridwiseGemmWelford::CountGridDescriptor_MBlock_MPerBlock_NBlock, + typename GridwiseGemmWelford::DefaultBlock2ETileMap, + has_main_loop>; + + const auto kernel_welford_layernorm = + kernel_welford_layernorm2d_second_half; + + avg_time += + launch_and_time_kernel(stream_config, + kernel_gemm_welford, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + static_cast(arg.p_workspace_e_grid_), + static_cast(arg.p_workspace_mean_), + static_cast(arg.p_workspace_var_), + static_cast(arg.p_workspace_count_), + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.gemm_mean_var_grid_desc_mblock_mperblock_nblock_, + arg.gemm_count_grid_desc_mblock_mperblock_nblock_, + arg.block_2_etile_map_, + arg.NRaw_); + + index_t MBlockClusterLength = + math::integer_divide_ceil(M, LayernormBlockTileSize_M_N::At(0)); + index_t NBlockClusterLength = + math::integer_divide_ceil(N, LayernormBlockTileSize_M_N::At(1)); + grid_size = MBlockClusterLength * NBlockClusterLength; + + index_t numMeanVarCountBlockTileIteration_N = math::integer_divide_ceil( + arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1)); + + avg_time += launch_and_time_kernel( + stream_config, + kernel_welford_layernorm, + dim3(grid_size), + dim3(BlockSize), + 0, + static_cast(arg.p_workspace_e_grid_), + static_cast(arg.p_workspace_mean_), + static_cast(arg.p_workspace_var_), + static_cast(arg.p_workspace_count_), + arg.p_gamma_grid_, + arg.p_beta_grid_, + arg.p_h_grid_, + arg.layernorm_e_grid_desc_m_n_, + arg.h_grid_desc_m_n_, + arg.layernorm_mean_var_grid_desc_m_nblock_, + arg.layernorm_count_grid_desc_m_nblock_, + arg.gamma_grid_desc_n_, + arg.beta_grid_desc_n_, + numMeanVarCountBlockTileIteration_N, + NBlockClusterLength, + arg.epsilon_, + arg.h_element_op_); + + return avg_time; + }; + + if(GridwiseGemmWelford::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + size_t GetWorkSpaceSize(const 
BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + size_t workspace_size = 0; + + int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_; + + // workspace for welford intermediate mean + workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 64; + + // workspace for welford intermediate variance + workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 64; + + // workspace for welford intermediate count + workspace_size += pArg_->gemm_nblock_ * sizeof(int32_t) + 64; + + if constexpr(!is_same_v) + workspace_size += pArg_->MRaw_ * pArg_->NRaw_ * sizeof(EMeanVarDataType); + + return (workspace_size); + }; + + void SetWorkSpacePointer(BaseArgument* pArg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + + int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_; + + // setup buffer used for intermediate welford mean + pArg_->p_workspace_mean_ = static_cast(pArg_->p_workspace_); + + index_t mean_space_sz = gemm_welford_size * sizeof(EMeanVarDataType); + mean_space_sz = math::integer_least_multiple(mean_space_sz, 64); + + // setup buffer used for intermediate welford varirance + pArg_->p_workspace_var_ = reinterpret_cast(pArg_->p_workspace_mean_) + mean_space_sz; + + index_t variance_space_sz = gemm_welford_size * sizeof(EMeanVarDataType); + variance_space_sz = math::integer_least_multiple(variance_space_sz, 64); + + // setup buffer used for intermediate welford count + pArg_->p_workspace_count_ = + reinterpret_cast(pArg_->p_workspace_var_) + variance_space_sz; + + index_t count_space_sz = gemm_welford_size * sizeof(int32_t); + count_space_sz = math::integer_least_multiple(count_space_sz, 64); + + if constexpr(!is_same_v) + pArg_->p_workspace_e_grid_ = + reinterpret_cast(pArg_->p_workspace_count_) + count_space_sz; + else + pArg_->p_workspace_e_grid_ = static_cast(pArg_->p_h_grid_); + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + // check vector load/store + { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // check vector load of A + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of Ds + // only support RowMajor for now + bool all_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(!is_same_v) + { + all_valid = false; + } + }); + + if(!all_valid) + { + return false; + } + + // check vector store of E + // E and H only support RowMajor for now + if constexpr(is_same_v && is_same_v) + { + if(arg.NRaw_ % PostShuffleScalarPerVector != 0 || + arg.NRaw_ % LayernormGammaSrcVectorSize != 0 || + arg.NRaw_ % 
LayernormBetaSrcVectorSize != 0 || + arg.NRaw_ % LayernormHDstVectorSize != 0) + { + return false; + } + } + else + { + return false; + } + } + + return GridwiseGemmWelford::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.gemm_e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + const void* p_gamma, + const void* p_beta, + void* p_h, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideH, + double epsilon, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + HElementwiseOperation h_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_gamma, + p_beta, + p_h, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideH, + epsilon, + a_element_op, + b_element_op, + cde_element_op, + h_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + const void* p_gamma, + const void* p_beta, + void* p_h, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideH, + double epsilon, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + HElementwiseOperation h_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_gamma, + p_beta, + p_h, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideH, + epsilon, + a_element_op, + b_element_op, + cde_element_op, + h_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmMultipleDLayernorm_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << GemmMPerBlock << ", " + << GemmNPerBlock << ", " + << GemmKPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << PostShuffleThreadClusterSize_M_N::At(I0) << ", " + << PostShuffleThreadClusterSize_M_N::At(I1) << ", " + << LayernormThreadClusterSize_M_N::At(I0) << ", " + << LayernormThreadClusterSize_M_N::At(I1) << ", " + << LayernormThreadSliceSize_M + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; // namespace device + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp new file mode 100755 index 000000000000..c048e7249c30 --- /dev/null +++ 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -0,0 +1,692 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_multiple_d_multiple_r_xdl_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + FloatRsPointer p_rs_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const QsElementwiseOperation qs_element_op, + const RsElementwiseOperation rs_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_rs_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + rs_grid_desc_mblock_mperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = p_rs_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = qs_element_op; + ignore = rs_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = rs_grid_desc_mblock_mperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[AK0, M, AK1] +// input : B[AK0, N, AK1] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// output : R0[M], R1[M], ... +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Q0 = reduce0(q_op0(E)), Q1 = reduce1(q_op0(E)), ... +// R0 = r_op0(Q0), R1 = r_op1(Q1), ... +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle + : public DeviceGemmMultipleDMultipleR +{ + using DeviceOp = DeviceGemmMultipleDMultipleR_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t NumRTensor = RsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + // assume D is packed tensor + static auto MakeRGridDescriptor_M(index_t MRaw) + { + const auto r_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor(r_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return r_grid_desc_mraw; + } + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + using RGridDesc_M = decltype(MakeRGridDescriptor_M(1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + DsDataType, + EDataType, + ReduceAccDataType, + RsDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + QsElementwiseOperation, + RsElementwiseOperation, + ThreadReduceOperations, + InMemoryDataOperationEnum::Set, + RsGlobalMemoryDataOperation, + AGridDesc_M_K, + BGridDesc_N_K, + EGridDesc_M_N, + RGridDesc_M, + NumGemmKPrefetchStage, + 
BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, + CDEReduceThreadTransferScalarPerVector_NPerBlock, + RThreadTransferDstScalarPerVector_MPerBlock, + LoopSched>; + + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + + using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + std::array p_rs_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + QsElementwiseOperation qs_element_op, + RsElementwiseOperation rs_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, // FIXME + p_e_grid_{static_cast(p_e_grid)}, + p_rs_grid_{}, // FIXME + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)}, + r_grid_desc_m_{DeviceOp::MakeRGridDescriptor_M(MRaw)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + rs_grid_desc_mblock_mperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + qs_element_op_{qs_element_op}, + rs_element_op_{rs_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + e_grid_desc_m_n_, + r_grid_desc_m_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + const auto d_grid_desc_m_n = + DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + d_grid_desc_m_n); + }); + + static_for<0, NumRTensor, 1>{}([&](auto i) { + using RDataType = remove_cvref_t>; + + p_rs_grid_(i) = static_cast(p_rs_grid[i]); + + rs_grid_desc_mblock_mperblock_(i) = + GridwiseGemm::MakeRGridDescriptor_MBlock_MPerBlock(r_grid_desc_m_); + }); + } + } + + // private: + // pointers + const ADataType* p_a_grid_; + 
const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + typename GridwiseGemm::RsGridPointer p_rs_grid_; + + // tensor descriptors + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + EGridDesc_M_N e_grid_desc_m_n_; + RGridDesc_M r_grid_desc_m_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor> + ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different + // type from E + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + + StaticallyIndexedArray + rs_grid_desc_mblock_mperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + QsElementwiseOperation qs_element_op_; + RsElementwiseOperation rs_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.r_grid_desc_m_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_multiple_d_multiple_r_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + typename GridwiseGemm::RsGridPointer, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + QsElementwiseOperation, + RsElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + ck::StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor>, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + ck::StaticallyIndexedArray< + typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock, + NumRTensor>, + typename GridwiseGemm::DefaultBlock2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.p_rs_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.qs_element_op_, + arg.rs_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.rs_grid_desc_mblock_mperblock_, + arg.block_2_etile_map_); + }; + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = 
StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.r_grid_desc_m_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + std::array p_rs, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + QsElementwiseOperation qs_element_op, + RsElementwiseOperation rs_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + p_rs, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + std::array p_rs, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + QsElementwiseOperation qs_element_op, + RsElementwiseOperation rs_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + p_rs, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmMultipleDMultipleR_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp new file mode 100755 index 000000000000..35f1c77f8799 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,740 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
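// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the upstream header): the device
// op defined below follows the same host-side flow as the other DeviceGemm
// variants added in this diff. `DeviceOp` is a hypothetical alias for a fully
// specialized DeviceGemmMultipleD_Wmma_CShuffle; the template arguments are
// deliberately omitted, and the pointers, sizes and element-wise ops are
// placeholders, not values taken from this patch.
//
//   // using DeviceOp = DeviceGemmMultipleD_Wmma_CShuffle</* dtypes, layouts, tile config */>;
//   // auto argument = DeviceOp::MakeArgument(p_a, p_b, {/* p_ds */}, p_e,
//   //                                        M, N, K, StrideA, StrideB,
//   //                                        {/* StrideDs */}, StrideE,
//   //                                        a_op, b_op, cde_op);
//   // if(DeviceOp::IsSupportedArgument(argument))
//   // {
//   //     float ave_time = DeviceOp::MakeInvoker().Run(argument, StreamConfig{});
//   // }
// ---------------------------------------------------------------------------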
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD +{ + using DeviceOp = DeviceGemmMultipleD_Wmma_CShuffle; + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + // K1 = Max Vector Access Pixels + static constexpr auto K1Number = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + + static constexpr auto AEnableLds_auto = + (NWaves == 1 && is_same::value) ? false : true; + static constexpr auto BEnableLds_auto = + (MWaves == 1 && is_same::value) ? false : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = false; + static constexpr auto BEnableLds_manu = false; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + // Describe how data read from Global memory + static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else if constexpr(is_same::value) + { + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + }(); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(AEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + a_grid_desc_m_k, + 
make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_n_k = [&]() { + if constexpr(is_same::value) + { + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else if constexpr(is_same_v) + { + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + }(); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + template + static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + static auto MakeDsGridDescriptor_M_N(const std::array& Ms, + const std::array& Ns, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(Ms[i], Ns[i], DsStride[i]); + }, + Number{}); + } + + // Gridwise descriptor, mapping to whole given provblem. 
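// Worked example for the K decomposition used above when LDS is bypassed
// (the numbers are hypothetical, chosen only to make the arithmetic concrete):
// with K1 = 16 this file sets WmmaK = 32, and with A_KRow = 2 that gives
// A_K0PerWmma = WmmaK / A_KRow / K1 = 32 / 2 / 16 = 1. For a problem with
// K = 256, A_KWmma = K / WmmaK = 8, so the original K extent is recovered as
// A_KWmma * A_K0PerWmma * A_KRow * K1 = 8 * 1 * 2 * 16 = 256, which is exactly
// the unmerge performed by transform_tensor_descriptor above.
//
//   // constexpr int K1_ = 16, WmmaK_ = 32, A_KRow_ = 2, A_K0PerWmma_ = WmmaK_ / A_KRow_ / K1_;
//   // static_assert(8 * A_K0PerWmma_ * A_KRow_ * K1_ == 256, "K = A_KWmma * A_K0PerWmma * A_KRow * K1");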
+ using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1)); + using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1)); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + // GridwiseOp + using GridwiseOp = GridwiseGemmMultipleD_Wmma< + // DataType Family + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + // InMemory Data Descriptor + AGridDesc, + BGridDesc, + DsGridDesc_M_N, + EGridDesc_M_N, + // ElementwiseOp Family + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + // Tiling Family + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + K1, + MRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + AEnableLds, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BEnableLds, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + NumPrefetch, + LoopSched, + PipelineVer>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc{}, + b_grid_desc{}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{}, + ds_grid_desc_mblock_mperblock_nblock_nperblock{}, + e_grid_desc_mblock_mperblock_nblock_nperblock{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + MRaw_{M}, + NRaw_{N}, + KRaw_{K} + { + a_grid_desc = DeviceOp::MakeAGridDescriptor(M, K, StrideA); + b_grid_desc = DeviceOp::MakeBGridDescriptor(K, N, StrideB); + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(M, N, StrideDs[i]); + }); + e_grid_desc_m_n_ = DeviceOp::MakeEGridDescriptor_M_N(M, N, StrideE); + + block_2_ctile_map_ = GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01); + + if(GridwiseOp::CheckValidity(a_grid_desc, + b_grid_desc, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_ctile_map_)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + } + 
} + + // Pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseOp::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // Tensor Descriptors + AGridDesc a_grid_desc; + BGridDesc b_grid_desc; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock; + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock; + + // Block to Tile mapping + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_; + + // Idle + index_t M01_; + index_t N01_; + + // ElementwiseOp + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking vector load/store + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseOp::CheckValidity(arg.a_grid_desc, + arg.b_grid_desc, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + const auto K = [&]() { + if constexpr(AEnableLds) + { + return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2); + } + else + { + return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) * + arg.a_grid_desc.GetLength(I4) * arg.a_grid_desc.GetLength(I6); + } + }(); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_gemm_mupltipe_d_wmma_cshuffle< + GridwiseOp, + ADataType, + BDataType, + typename GridwiseOp::DsGridPointer, + EDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + remove_reference_t< + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + remove_reference_t, + has_main_k_block_loop>; // Last Option is W/O + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_grid_desc, + arg.b_grid_desc, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.block_2_ctile_map_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + // check vector load/store + { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = 
ck::tensor_layout::gemm::ColumnMajor; + + // check vector load of A + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of Ds + // only support RowMajor for now + bool all_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(!is_same_v) + { + all_valid = false; + } + }); + + if(!all_valid) + { + return false; + } + + // check vector store of E + // only support RowMajor for now + if constexpr(is_same_v) + { + if(arg.NRaw_ % CDEShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + + return GridwiseOp::CheckValidity(arg.a_grid_desc, + arg.b_grid_desc, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op}; + } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmMultipleD_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << K1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "NumPrefetch: " + << 
NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp new file mode 100755 index 000000000000..f193b093d157 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,900 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#ifndef __HIPCC_RTC__ +#include +#include +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#endif + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD +{ + using DeviceOp = DeviceGemmMultipleD_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + static auto MakeDsGridDescriptor_M_N(const Array& MRaws, + const Array& NRaws, + const Array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + // desc for problem definition + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ADataType, + BDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + 
CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + +#ifndef __HIPCC_RTC__ + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + MRaw_{MRaw}, + NRaw_{NRaw}, + KRaw_{KRaw} + { + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + } + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + 
ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking vector load/store + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + BDataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + const auto K = arg.a_grid_desc_m_k_.GetLength(I1); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + +#endif + + static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_) + { + // check vector load/store + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + // check vector load of A + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector load of Ds + // only support RowMajor for now + bool all_valid = true; + + static_for<0, 
NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(!is_same_v) + { + all_valid = false; + } + }); + + if(!all_valid) + { + return false; + } + + // check vector store of E + // only support RowMajor for now + if constexpr(is_same_v) + { + if(NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + return true; + } + +#ifndef __HIPCC_RTC__ + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + return IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_) and + GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{{LoopScheduler::Default, "Default"}, + { LoopScheduler::Interwave, + "Interwave" }}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + { PipelineVersion::v2, + "v2" }}; + + // clang-format off + str << "DeviceGemmMultipleD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +#endif + + template + struct Descriptor + { + static constexpr auto ds_tuple() + { + return transform_tuples( + [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); }, + DsDesc{}); + } + using AGridDesc_M_K = + remove_cvref_t; + using BGridDesc_N_K = + remove_cvref_t; + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = + remove_cvref_t; + 
using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_tuple()))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>; + using Block2ETileMap = remove_cvref_t; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k; + BGridDesc_N_K b_grid_desc_n_k; + DsGridDesc_M_N ds_grid_desc_m_n; + EGridDesc_M_N e_grid_desc_m_n; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map; + + // element-wise op + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + CDEElementwiseOperation cde_element_op; + + // for checking vector load/store + index_t MRaw; + index_t NRaw; + index_t KRaw; + + bool has_main_k_block_loop = true; + + constexpr Descriptor(ADesc a, + BDesc b, + DsDesc ds, + EDesc e, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CDEElementwiseOperation cde_element_op_) + : a_grid_desc_m_k{DeviceOp::matrix_padder.PadADescriptor_M_K(a)}, + b_grid_desc_n_k{DeviceOp::matrix_padder.PadBDescriptor_N_K(b)}, + ds_grid_desc_m_n{transform_tuples( + [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); }, + ds)}, + e_grid_desc_m_n{DeviceOp::matrix_padder.PadCDescriptor_M_N(e)}, + a_grid_desc_ak0_m_ak1{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k)}, + b_grid_desc_bk0_n_bk1{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock{ + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + transform_tuples( + [&](auto d) constexpr { + return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); + }, + ds))}, + e_grid_desc_mblock_mperblock_nblock_nperblock{ + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n)}, + block_2_etile_map{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n)}, + has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop( + a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + cde_element_op{cde_element_op_}, + MRaw{e.GetLength(I0)}, + NRaw{e.GetLength(I1)}, + KRaw{a.GetLength(I1)} + { + } + + constexpr bool IsValid() const + { + return GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map) and + IsSupported(MRaw, NRaw, KRaw); + } + + constexpr index_t GetBlockSize() const { return BlockSize; } + + constexpr index_t GetGridSize() const + { + return block_2_etile_map.CalculateGridSize(e_grid_desc_m_n); + } + }; + + template + static constexpr auto + make_descriptor(ADesc a, + BDesc b, + DsDesc ds, + EDesc e, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation cde_element_op = CDEElementwiseOperation{}) + { + return 
Descriptor( + a, b, ds, e, a_element_op, b_element_op, cde_element_op); + } + + template + __device__ static void Run(const Desc& desc, + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid) + { + __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +#ifndef __HIPCC_RTC__ + assert(desc.IsValid()); +#endif + if(desc.has_main_k_block_loop) + { + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + desc.cde_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.ds_grid_desc_mblock_mperblock_nblock_nperblock, + desc.e_grid_desc_mblock_mperblock_nblock_nperblock, + desc.block_2_etile_map); + } + else + { + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared_block, + desc.a_element_op, + desc.b_element_op, + desc.cde_element_op, + desc.a_grid_desc_ak0_m_ak1, + desc.b_grid_desc_bk0_n_bk1, + desc.ds_grid_desc_mblock_mperblock_nblock_nperblock, + desc.e_grid_desc_mblock_mperblock_nblock_nperblock, + desc.block_2_etile_map); + } + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp new file mode 100755 index 000000000000..010a23c66cb5 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -0,0 +1,414 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
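// Illustrative usage sketch: the device op defined in this header follows the usual CK
// host-side sequence of MakeArgument(), IsSupportedArgument(), MakeInvoker() and
// Invoker::Run(). The concrete template parameters, pointer names and sizes in this sketch
// are hypothetical and only make the calling convention concrete.
//
//   using DeviceOp = DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad</* layouts, types, tuning */>;
//   auto argument  = DeviceOp::MakeArgument(p_a, p_b, /*p_ds=*/{}, p_e,
//                                           M, N, K, StrideA, StrideB, /*StrideDs=*/{}, StrideE,
//                                           AElementwiseOperation{}, BElementwiseOperation{},
//                                           CDEElementwiseOperation{});
//   if(DeviceOp::IsSupportedArgument(argument))
//   {
//       auto invoker = DeviceOp::MakeInvoker();
//       invoker.Run(argument, StreamConfig{});
//   }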
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad + : public DeviceGemmMultipleD +{ + static constexpr auto I1 = Number<1>{}; + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GridwiseGemm = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + GemmSpec, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferScalarPerVector, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferScalarPerVector, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer>; + + using Argument = typename GridwiseGemm::Argument; + + struct Invoker : public BaseInvoker + { + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load< + GridwiseGemm, + ADataType, + BDataType, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + typename GridwiseGemm::AGridDesc_AK0_M_AK1, + typename GridwiseGemm::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + const auto K = arg.a_grid_desc_m_k_.GetLength(I1); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!ck::is_lds_direct_load_supported()) + { + return false; + } + + // Check vector load/store. + { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // Check vector load of A. + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + if(arg.MRaw_ % ABlockTransferScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // Check vector load of B. + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + if(arg.NRaw_ % BBlockTransferScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // Check vector load of Ds. + // For now, only the RowMajor layout is supported. + bool all_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(!is_same_v) + { + all_valid = false; + } + }); + + if(!all_valid) + { + return false; + } + + // Check vector load of E. + // For now, only the RowMajor layout is supported. 
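// Illustrative example for the E check below (hypothetical sizes): with NRaw_ = 4096 and
// CDEBlockTransferScalarPerVector_NPerBlock = 8, 4096 % 8 == 0 and vectorized row-major
// stores of E are allowed; NRaw_ = 4100 would fail the divisibility test and the argument
// would be rejected.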
+ if constexpr(is_same_v) + { + if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideDs, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{ + {PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}, {PipelineVersion::v4, "v4"}}; + + // clang-format off + str << "DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferScalarPerVector << ", " + << BBlockTransferScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp new file mode 100755 index 000000000000..0796614bd41c --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp @@ -0,0 +1,793 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + + std::array DsSize; + + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + ck::utility::RotatingMemWrapperMultiD rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 
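// Illustrative example (hypothetical tile sizes): a 128x128 block tile with BlockSize = 256
// gives 128 * 128 / 256 = 64 <= 128 output elements per thread, so the branch resolves to an
// occupancy hint of 2; a 256x256 tile with the same block size gives 256 > 128 and resolves
// to 1.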
2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + 
kernel_gemm_xdl_cshuffle_v3_multi_d; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + 
GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + 
{BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 + : public DeviceGemmMultipleD_ABScale +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = + (BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave && + MPerBlock * NPerBlock / BlockSize > 64) + ? 
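// Illustrative example (hypothetical tile sizes): with Intrawave scheduling, a 128x128 block
// tile and BlockSize = 256 gives 128 * 128 / 256 = 64, which is not > 64, so the expression
// selects 2; a 256x128 tile with the same block size gives 128 > 64 and selects 1.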
1 + : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + // if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK != + // KPerBlock) + // { + // return false; + // } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + static_cast(p_a_scale), + static_cast(p_b_scale), + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + static_cast(p_a_scale), + static_cast(p_b_scale), + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + 
{BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle + : public DeviceGemmMultipleDSplitKBPreShuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + int GetPreShuffleParameters() override { return NPerXDL; } + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + + std::array DsSize; + + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + ck::utility::RotatingMemWrapperMultiD rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize / + 4 * (1 + GridwiseGemm::NWave); + constexpr auto estimated_reg_b = + NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize / 4 * (2); + constexpr auto estimated_reg_c = + MPerBlock * NPerBlock * sizeof(GemmAccDataType) / BlockSize / 4; + constexpr auto estimated_reg_total = + estimated_reg_a + estimated_reg_b + estimated_reg_c; + + constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 
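// Illustrative example (hypothetical configuration): the estimates above approximate 32-bit
// registers per thread (bytes per block / BlockSize / 4). Assuming MPerBlock = NPerBlock = 128,
// KPerBlock = 64, BlockSize = 256, 2-byte A/B, a 4-byte accumulator and NWave = 2:
// estimated_reg_a = 128*64*2/256/4*(1+2) = 48, estimated_reg_b = 128*64*2/256/4*2 = 32,
// estimated_reg_c = 128*128*4/256/4 = 64, so estimated_reg_total = 144 < 256 and the
// expression selects an occupancy hint of 2.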
1 : 2; + + // static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && + // has_main_k_block_loop, "only impl BlockGemmPipelineVersion::v3 and has mainloop right + // now"); + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + throw std::runtime_error("todo: only v1 v2 and v3 support now"); + } + } +#if 0 + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { +#if 0 + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + 
GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } +#endif + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + throw std::runtime_error("todo: only v3 support now"); + } + } +#endif + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // 
polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle + : public DeviceGemmMultipleD_BlockScale_BPreshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + 
BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + int GetPreShuffleParameters() override { return NPerXDL; } + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + // unconditional 2 to remove agpr usage + constexpr index_t minimum_occupancy = 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds< + GridwiseGemm, + true, + 
InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + // if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK != + // KPerBlock) + // { + // return false; + // } + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + // Padding to release this restriction + if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + static_cast(p_a_scale), + static_cast(p_b_scale), + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + static_cast(p_a_scale), + static_cast(p_b_scale), + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic 
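// Illustrative example of the string produced below (hypothetical tuning parameters): roughly
// "DeviceGemmXdlUniversal<Default, RCR> BlkSize: 256, BlkTile: 128x128x64, ..." followed by
// the block-scheduler and pipeline-version names.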
+ std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. +template +struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<0, ReduceOperations::Size()> +{ + using DeviceOp = DeviceGemmReduce_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + 
make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + 
make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + // assume Reduce is packed tensor + static auto MakeReduceGridDescriptor_M(index_t MRaw) + { + const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw)); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor(d_grid_desc_mraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return d_grid_desc_mraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using 
BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using ReduceGridDesc_M = decltype(MakeReduceGridDescriptor_M(1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + ReduceAccDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ReduceOperations, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + InMemoryDataOperationEnum::Set, + ReduceGlobalMemoryDataOperation, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + ReduceGridDesc_M, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + LoopSched>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + ReducePtrsGlobal p_reduces_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ReduceInElementwiseOperations reduce_in_element_ops, + ReduceAccElementwiseOperations reduce_out_element_ops) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_reduces_grid_{p_reduces_grid}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + reduce_grid_desc_mblock_mperblock_{}, + block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + reduce_in_element_ops_{reduce_in_element_ops}, + reduce_out_element_ops_{reduce_out_element_ops} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + reduce_grid_desc_mblock_mperblock_ = + GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_); + } + } + + // private: + const ADataType* p_a_grid_; + 
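// Device pointers for the B operand, the C output and the per-reduction result buffers
// follow; the grid descriptors and the block-to-C-tile map further below are precomputed
// on the host by the constructor above from the raw M/N/K sizes and strides.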
const BDataType* p_b_grid_; + CDataType* p_c_grid_; + ReducePtrsGlobal p_reduces_grid_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + ReduceGridDesc_M reduce_grid_desc_m_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock + reduce_grid_desc_mblock_mperblock_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + ReduceInElementwiseOperations reduce_in_element_ops_; + ReduceAccElementwiseOperations reduce_out_element_ops_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + + std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) + << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float elapsed_time = 0.0f; + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + true>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_reduces_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_reduce_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + ReducePtrsGlobal, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + ReduceInElementwiseOperations, + ReduceAccElementwiseOperations, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock, + typename GridwiseGemm::DefaultBlock2CTileMap, + false>; + + elapsed_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_reduces_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.reduce_in_element_ops_, + arg.reduce_out_element_ops_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.reduce_grid_desc_mblock_mperblock_, + arg.block_2_ctile_map_); + } + + return elapsed_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static constexpr int NumReduce = ReduceOperations::Size(); + static auto MakeArgument(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array 
gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op) + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + return Argument{static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const void* p_bias, + std::array p_ds, + void* p_c, + std::array p_reduces, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideB, + ck::index_t StrideC, + std::array StrideDs, + std::array gemm_element_ops, + std::array d_element_ops, + std::array reduce_in_element_op, + std::array reduce_out_element_op, + ck::index_t = 1) override + { + (void)p_bias; + (void)p_ds; + (void)StrideDs; + (void)d_element_ops; + + ReducePtrsGlobal reduce_tuple = generate_tuple( + [&](auto I) { + auto tmp = ReducePtrsGlobal{}[I]; + using T = remove_pointer_t; + return static_cast(p_reduces[I]); + }, + Number{}); + + ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceInElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_in_element_op[I])); + }, + Number{}); + ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple( + [&](auto I) { + auto tmp = ReduceAccElementwiseOperations{}[I]; + using T = remove_pointer_t; + return *(static_cast(reduce_out_element_op[I])); + }, + Number{}); + + AElementwiseOperation a_element_op = + *(static_cast(gemm_element_ops[0])); + BElementwiseOperation b_element_op = + *(static_cast(gemm_element_ops[1])); + CElementwiseOperation c_element_op = + *(static_cast(gemm_element_ops[2])); + + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + reduce_tuple, + M, + N, + K, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmReduce_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << 
MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp new file mode 100755 index 000000000000..a7cc546f5374 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp @@ -0,0 +1,651 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmWmma_CShuffle : public DeviceGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + // K1 = Max Vector Access Pixels + static constexpr auto K1Number = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false; + static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false; + + static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) && + is_same::value) + ? false + : true; + static constexpr auto BEnableLds_auto = + (MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) && + is_same::value) + ? 
false + : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = false; + static constexpr auto BEnableLds_manu = false; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1); + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + // Describe how data read from Global memory + static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else if constexpr(is_same::value) + { + const auto a_grid_desc_mraw_kraw = + make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA)); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + }(); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(AEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_n_k = [&]() { + if constexpr(is_same::value) + { + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else if constexpr(is_same_v) + { + const auto b_grid_desc_nraw_kraw = + make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1)); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + }(); + + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, 
K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); + } + + // Gridwise descriptor, mapping to whole given provblem. + using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1)); + using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemm_Wmma; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + MRaw_{M}, + NRaw_{N}, + KRaw_{K} + { + a_grid_desc_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor(M, K, StrideA); + b_grid_desc_k0_n_k1_ = DeviceGemmWmma_CShuffle::MakeBGridDescriptor(K, N, StrideB); + c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC); + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc a_grid_desc_; + BGridDesc b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + // for checking vector load/store + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmWmma_CShuffle::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = [&]() { + if constexpr(AEnableLds) + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2); + } + else + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) * + arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6); + } + }(); + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_gemm_wmma< + GridwiseGemm, + ADataType, + BDataType, + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t< + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + has_main_k_block_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v || + is_same_v)) + { + printf("DeviceOp err: AccDataType"); + return false; + } + } + else + { + printf("DeviceOp err: Arch"); + return false; + } + + // check vector load/store + { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // check vector load of A + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector laod of B + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + // FIXME: not rigorous + if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector store of C + // only support RowMajor for now + if constexpr(is_same_v) + { + if(arg.NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + 
index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmWmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << K1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp new file mode 100755 index 000000000000..a921962c6781 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp @@ -0,0 +1,371 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// @brief \"Universal\" GEMM operation with SplitK support. +/// +/// @par Overview +/// This GEMM operation implements the following mathematical equation: +/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N})) +/// Where A, B are input tensors and C is the output tensor. 
The A/B/C_op are +/// elementwise operations applied to the A, B, and C tensors, respectively. +/// The \"universal\" gemm comes with multiple pipelines optimized for different usage +/// scenarios. That's why it's called \"universal\". It's universal through it's design +/// and versatilty. +/// +/// @note This Kernel implementation supports SplitK algorithm. It can be configured +/// to split the dot product accumulated over the K dimension into multiple working groups. +/// The partial products of different workgroups are then reduced using the AtomicAdd +/// operation. +/// +/// @tparam ALayout A tensor data layout. +/// @tparam BLayout B tensor data layout. +/// @tparam CLayout C tensor data layout. +/// @tparam ADataType A tensor data type. +/// @tparam BDataType B tensor data type. +/// @tparam CDataType C tensor data type. +/// @tparam AccDataType The accumulation data type related to the hardware +/// matrix-multiplication instruction. +/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into +/// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor +/// (after GEMM). +/// @tparam GemmSpec Determines used "padding" version. +/// @tparam BlockSize The number of threads within workgroup. +/// @tparam MPerBlock The input/output data tile size in the M dimension. +/// @tparam NPerBlock The input/output data tile size in the N dimension. +/// @tparam KPerBlock The input data tile size in the K dimension. +/// @tparam AK1 The vector load size from global memory for A tensor. +/// @tparam BK1 The vector load size from global memory for B tensor. +/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction. +/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront. +/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront. +/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question, "How many threads can be +/// arranged on each input data axis?" +/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory. +/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. 
+/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question: "How many threads to +/// arrange on each input data axis?" +/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory. +/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in M dimension. +/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in N dimension. +/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// thread distribution used for storing data into output +/// tensor across output data layout dimensions. +/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. +/// Used when storing data to output tensor. +/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or +/// intrawave). +/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. +/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication +/// instructions. +/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication +/// instructions. +/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout +/// in global memory. Currently not supported! +/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout +/// in global memory (pre-shuffled). 
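/// @par Usage sketch (illustrative)
///      A minimal host-side sketch of how this device operation is driven, assuming an
///      instantiation whose element-wise operations are the common PassThrough functors;
///      the concrete template arguments and the device pointers p_a/p_b/p_c are
///      placeholders. Only the MakeArgument / MakeInvoker / IsSupportedArgument / Run
///      interface defined in this header is relied upon.
/// @code
///      using DeviceOp    = DeviceGemm_Wmma_CShuffleV3</* layouts, data types, tile shape, ... */>;
///      using PassThrough = ck::tensor_operation::element_wise::PassThrough; // assumed element-wise ops
///
///      // KBatch > 1 enables the SplitK path described above; partial C tiles are
///      // reduced with AtomicAdd.
///      auto argument = DeviceOp::MakeArgument(p_a, p_b, p_c,
///                                             M, N, K,
///                                             StrideA, StrideB, StrideC,
///                                             /*KBatch=*/2,
///                                             PassThrough{}, PassThrough{}, PassThrough{});
///      auto invoker  = DeviceOp::MakeInvoker();
///      if(DeviceOp::IsSupportedArgument(argument))
///      {
///          const float avg_ms = invoker.Run(argument, StreamConfig{}); // average kernel time, if measured
///      }
/// @endcode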
+template +struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2 +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Argument = typename GridwiseGemm::Argument; + + using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common; + + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; + + static bool IsSupportedArgument(const Argument& arg) + { + return DeviceGemmCommon::IsSupportedArgument(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteA() override { return PermuteA; } + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemm_Wmma_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << 
std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock << "x" << NPerBlock << "x" << KPerBlock << ", " + << "WaveTile: " + << MPerWmma << "x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale +{ + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Argument = typename GridwiseGemm::Argument; + + using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common; + + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; + + static bool IsSupportedArgument(const Argument& arg) + { + return DeviceGemmCommon::IsSupportedArgument(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + const BScaleDataType* p_b_scale, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + p_b_scale, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + 
index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + const void* p_b_scale, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + static_cast(p_b_scale), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemm_Wmma_CShuffleV3_BScale" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Wmma_CShuffleV3_Common +{ + + using Argument = typename GridwiseGemm::Argument; + + /// @brief Helper structure responsible for kernel invocation. + /// + /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU + /// kernel function. It usually determines the launched grid size prepares kernel + /// arguments as well as perform specific kernel configuration selection based on + /// runtime arguments. + /// + /// @note If appropriately configured it may measure kernel execution time. + /// + struct Invoker : public BaseInvoker + { + /// @brief This function issues GPU kernel execution. + /// @param arg The GPU kernel arguments. + /// @param stream_config The HIP stream configuration helper structure. + /// @return The kernel's average execution time (if time measurement is + /// enabled). + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * + sizeof(ADataType) / GridwiseGemm::APackedSize; + auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * + sizeof(BDataType) / GridwiseGemm::BPackedSize; + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 
2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + } + else + { + // TODO: Implement + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.KBatch > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp new file mode 100755 index 000000000000..5188ece3335d --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp @@ -0,0 +1,305 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
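// The device op below follows the usual CK host-side pattern: build an
// Argument, verify it with IsSupportedArgument, then launch through an
// Invoker (which here also selects between the main-K-loop and no-main-K-loop
// kernel instantiations). A minimal usage sketch, assuming a fully
// specialized alias `DeviceOp` for DeviceGemmXdl whose elementwise operations
// are PassThrough; the names below are illustrative and not part of this header:
//
//   using PassThrough = ck::tensor_operation::element_wise::PassThrough;
//
//   auto op      = DeviceOp{};
//   auto arg     = DeviceOp::MakeArgument(p_a, p_b, p_c, M, N, K,
//                                         StrideA, StrideB, StrideC,
//                                         PassThrough{}, PassThrough{}, PassThrough{});
//   auto invoker = DeviceOp::MakeInvoker();
//   if(DeviceOp::IsSupportedArgument(arg))
//   {
//       float avg_ms = invoker.Run(arg, StreamConfig{nullptr, /*time_kernel=*/true});
//   }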
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmXdl : public DeviceGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto K1Number = Number{}; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + NumPrefetch, + LoopSched, + PipelineVer>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + karg.Print(); + } + + if(!GridwiseGemm::CheckValidity(karg)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting"); + } + + const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K)) + { + const auto kernel = kernel_gemm_xdlops_v2r3; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); + } + else + { + const auto kernel = kernel_gemm_xdlops_v2r3; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& karg) + { + if(ck::get_device_name() == "gfx908") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v)) + { + return false; + } + } + else if(ck::is_lds_direct_load_supported()) + { + if constexpr(!(is_same_v || is_same_v || + is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + if(karg.K % K1 != 0) + { + return false; + } + + return GridwiseGemm::CheckValidity(karg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemmXdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 + << ">" + << " NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git 
a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp new file mode 100755 index 000000000000..dd41f4bca0a3 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp @@ -0,0 +1,309 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. +template +struct DeviceGemm_Xdl_CShuffle : public DeviceGemm +{ + using DeviceOp = DeviceGemm_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); + + const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1; + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v1; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v1; + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemm_Xdl_CShuffle" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff 
--git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp new file mode 100755 index 000000000000..ac2e826725c1 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp @@ -0,0 +1,394 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm +{ + static constexpr auto I1 = Number<1>{}; + + using GridwiseGemm = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad< + ALayout, + BLayout, + ck::Tuple<>, + ELayout, + ADataType, + BDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + ck::Tuple<>, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + GemmSpec, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferScalarPerVector, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferScalarPerVector, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer>; + + using Argument = typename GridwiseGemm::Argument; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load< + GridwiseGemm, + ADataType, + BDataType, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + typename GridwiseGemm::AGridDesc_AK0_M_AK1, + typename GridwiseGemm::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::Block2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + // arg.p_ds_grid_, + ck::Tuple<>{}, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + const auto K = arg.a_grid_desc_m_k_.GetLength(I1); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!ck::is_lds_direct_load_supported()) + { + return false; + } + + // Check vector load/store. + { + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + // Check vector load of A. + if constexpr(is_same_v && ABlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % ABlockTransferScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && ABlockTransferSrcVectorDim == 1) + { + if(arg.MRaw_ % ABlockTransferScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // Check vector load of B. + if constexpr(is_same_v && BBlockTransferSrcVectorDim == 2) + { + if(arg.KRaw_ % BBlockTransferScalarPerVector != 0) + { + return false; + } + } + else if constexpr(is_same_v && BBlockTransferSrcVectorDim == 1) + { + if(arg.NRaw_ % BBlockTransferScalarPerVector != 0) + { + return false; + } + } + else + { + return false; + } + + // Check vector load of E. + // For now, only the RowMajor layout is supported. 
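// With a RowMajor E, N is the contiguous (fastest-varying) dimension, so the
// epilogue issues vectorized stores of CDEBlockTransferScalarPerVector_NPerBlock
// elements along N. NRaw_ must therefore be a multiple of that vector width;
// for example, with a vector width of 8, N = 4096 is accepted while N = 4100
// is rejected (illustrative numbers only). Any other E layout falls through to
// the `return false` branch below.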
+ if constexpr(is_same_v) + { + if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + return false; + } + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + using EmptyDsPointers = std::array; + using EmptyDsStrides = std::array; + + return Argument{p_a, + p_b, + EmptyDsPointers{}, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + EmptyDsStrides{}, + StrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + using EmptyDsPointers = std::array; + using EmptyDsStrides = std::array; + + return std::make_unique(p_a, + p_b, + EmptyDsPointers{}, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + EmptyDsStrides{}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{ + {PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}, {PipelineVersion::v4, "v4"}}; + + // clang-format off + str << "DeviceGemm_Xdl_CShuffle_LdsDirectLoad" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferScalarPerVector << ", " + << BBlockTransferScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer] << ", " + << "Prefetch: " + << NumGemmKPrefetchStage; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp new file mode 100755 index 000000000000..3171208830c6 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp @@ -0,0 +1,811 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
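// The Stream-K device op defined below computes its own launch grid when the
// caller passes Grid_size < 0 (number of CUs times per-CU occupancy), and the
// Reduction strategy additionally requires a caller-provided workspace sized
// by GetWorkSpaceSize and attached via SetWorkSpacePointer. A rough host-side
// sketch, assuming a fully specialized alias `DeviceOp` with PassThrough
// elementwise operations; names and values below are illustrative only:
//
//   auto op  = DeviceOp{};
//   auto arg = DeviceOp::MakeArgument(p_a, p_b, p_c, M, N, K,
//                                     StrideA, StrideB, StrideC,
//                                     /*streamk_sel=*/1, /*Grid_size=*/-1,
//                                     PassThrough{}, PassThrough{}, PassThrough{},
//                                     StreamKReductionStrategy::Reduction);
//   void* workspace = nullptr;
//   hip_check_error(hipMalloc(&workspace, op.GetWorkSpaceSize(&arg)));
//   op.SetWorkSpacePointer(&arg, workspace);
//   if(DeviceOp::IsSupportedArgument(arg))
//   {
//       DeviceOp::MakeInvoker().Run(arg, StreamConfig{nullptr, /*time_kernel=*/true});
//   }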
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2 +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_streamk_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + float ave_time = 0; + + index_t k_grain = KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + if(arg.reduction_strategy == StreamKReductionStrategy::Atomic) + { + + hip_check_error(hipMemsetAsync( + arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_)); + } + + const auto Run = [&](const auto& kernel) { + dim3 grid_dim; + if(arg.Grid_size < 0) + { + int occupancy, num_cu; + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy, kernel, BlockSize, 0)); + hipDeviceProp_t dev_prop; + hipDevice_t dev; + hip_check_error(hipGetDevice(&dev)); + hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); + num_cu = dev_prop.multiProcessorCount; + arg.Grid_size = num_cu * occupancy; + grid_dim = arg.Grid_size; + } + else + grid_dim = arg.Grid_size; + + if(stream_config.flush_cache) + { + Argument arg_ = arg; + ck::utility::RotatingMemWrapper rotating_mem( + arg_, + stream_config.rotating_count, + arg_.M * arg_.K * sizeof(ADataType), + arg_.K * arg_.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, run_flush_cache, kernel, grid_dim, dim3(BlockSize), 0, arg_); + } + else + { + + if(arg.reduction_strategy == StreamKReductionStrategy::Atomic) + { + ave_time = launch_and_time_kernel( + stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg); + } + else if(arg.reduction_strategy == StreamKReductionStrategy::Reduction) + { + char* workspace_semaphore = + reinterpret_cast(arg.p_workspace_) + + arg.block_2_ctile_map_streamk.get_workspace_size_for_acc( + sizeof(GemmAccDataType)); + auto preprocess = [&]() { + hipError_t status = hipMemsetAsync( + workspace_semaphore, + 0, + // sizeof(uint32_t), + arg.block_2_ctile_map_streamk.get_workspace_size_for_semaphore(), + stream_config.stream_id_); + + // Check the status + hip_check_error(status); + }; + + ave_time = launch_and_time_kernel_with_preprocess( + stream_config, preprocess, kernel, grid_dim, dim3(BlockSize), 0, arg); + } + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 
1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + + Run(kernel); + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + else + { + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* p_arg = dynamic_cast(pArg); + if(p_arg->reduction_strategy == StreamKReductionStrategy::Reduction) + { + return p_arg->block_2_ctile_map_streamk.get_workspace_size(sizeof(GemmAccDataType)); + } + else + { + return 0; + } + } + + void SetWorkSpacePointer(BaseArgument* pArg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + 
{ + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + } + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + if(!is_bf16_atomic_supported() && std::is_same_v && + arg.Streamk_sel > 0) + { + return false; + } + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t streamk_sel, + index_t Grid_size, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + StreamKReductionStrategy reduction_strategy = StreamKReductionStrategy::Atomic) + { + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + index_t K_split = (K + KPerBlock - 1) / KPerBlock * KPerBlock; + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + int occupancy, num_cu; + const auto calculate_grid_size = [&](const auto& kernel) { + hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + hipDeviceProp_t dev_prop; + hipDevice_t dev; + hip_check_error(hipGetDevice(&dev)); + hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); + num_cu = dev_prop.multiProcessorCount; + Grid_size = num_cu * occupancy; + }; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == 
TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + calculate_grid_size(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + calculate_grid_size(kernel); + } + } + else + { + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + calculate_grid_size(kernel); + } + } + + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + streamk_sel, + Grid_size, + reduction_strategy}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t streamk_sel, + index_t Grid_size, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + StreamKReductionStrategy reduction_strategy = StreamKReductionStrategy::Atomic) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + streamk_sel, + Grid_size, + reduction_strategy); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include 
"ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. +template +struct DeviceGemm_Xdl_CShuffleV2 : public DeviceGemm +{ + using DeviceOp = DeviceGemm_Xdl_CShuffleV2; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v2< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); + + float ave_time = 0; + const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1; + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K) == 3) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v2; + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v2; + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGemm_Xdl_CShuffleV2" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff 
--git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp new file mode 100755 index 000000000000..dde21725d0f5 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp @@ -0,0 +1,887 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// @brief \"Universal\" GEMM operation with SplitK support. +/// +/// @par Overview +/// This GEMM operation implements the following mathematical equation: +/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N})) +/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are +/// elementwise operations applied to the A, B, and C tensors, respectively. +/// The \"universal\" gemm comes with multiple pipelines optimized for different usage +/// scenarios. That's why it's called \"universal\". It's universal through it's design +/// and versatilty. +/// +/// @note This Kernel implementation supports SplitK algorithm. It can be configured +/// to split the dot product accumulated over the K dimension into multiple working groups. +/// The partial products of different workgroups are then reduced using the AtomicAdd +/// operation. +/// +/// @tparam ALayout A tensor data layout. +/// @tparam BLayout B tensor data layout. +/// @tparam CLayout C tensor data layout. +/// @tparam ADataType A tensor data type. +/// @tparam BDataType B tensor data type. +/// @tparam CDataType C tensor data type. +/// @tparam GemmAccDataType The accumulation data type related to the hardware +/// matrix-multiplication instruction. +/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into +/// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor +/// (after GEMM). +/// @tparam GemmSpec Determines used "padding" version. +/// @tparam BlockSize The number of threads within workgroup. +/// @tparam MPerBlock The input/output data tile size in the M dimension. +/// @tparam NPerBlock The input/output data tile size in the N dimension. +/// @tparam KPerBlock The input data tile size in the K dimension. +/// @tparam AK1 The vector load size from global memory for A tensor. +/// @tparam BK1 The vector load size from global memory for B tensor. +/// @tparam MPerXDL M size of matrix-fused-multiply-add instruction. +/// @tparam NPerXDL N size of matrix-fused-multiply-add instruction. 
+/// @tparam MXdlPerWave The number of iterations in the M dimension over output tile per wavefront. +/// @tparam NXdlPerWave The number of iterations in the N dimension over output tile per wavefront. +/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question, "How many threads can be +/// arranged on each input data axis?" +/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory. +/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question: "How many threads to +/// arrange on each input data axis?" +/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory. +/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam CShuffleMXdlPerWavePerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in M dimension. +/// @tparam CShuffleNXdlPerWavePerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in N dimension. +/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// thread distribution used for storing data into output +/// tensor across output data layout dimensions. +/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. +/// Used when storing data to output tensor. +/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or +/// intrawave). +/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. 
+/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication +/// instructions. +/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication +/// instructions. +/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout +/// in global memory. Currently not supported! +/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout +/// in global memory (pre-shuffled). +template +struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2 +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Argument = typename GridwiseGemm::Argument; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + /// @brief Helper structure responsible for kernel invocation. + /// + /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU + /// kernel function. It usually determines the launched grid size prepares kernel + /// arguments as well as perform specific kernel configuration selection based on + /// runtime arguments. + /// + /// @note If appropriately configured it may measure kernel execution time. + /// + struct Invoker : public BaseInvoker + { + /// @brief This function issues GPU kernel execution. + /// @param arg The GPU kernel arguments. + /// @param stream_config The HIP stream configuration helper structure. + /// @return The kernel's average execution time (if time measurement is + /// enabled). + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * + sizeof(ADataType) / APackedSize; + auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * + sizeof(BDataType) / BPackedSize; + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 
2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + 
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + } + else + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + 
{ + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteA() override { return PermuteA; } + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3_b_preshuffle< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + 
ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Argument = typename GridwiseGemm::Argument; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + int GetPreShuffleParameters() override { return NPerXDL; } + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * + sizeof(ADataType) / APackedSize; + auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * + sizeof(BDataType) / BPackedSize; + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize / + 4 * (1 + GridwiseGemm::NWave); + constexpr auto estimated_reg_b = + NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize / 4 * (2); + constexpr auto estimated_reg_c = + MPerBlock * NPerBlock * sizeof(GemmAccDataType) / BlockSize / 4; + constexpr auto 
estimated_reg_total = + estimated_reg_a + estimated_reg_b + estimated_reg_c; + + constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + throw std::runtime_error("Only support pipeline ver v1, v2, v3 now!"); + } + } +#if 0 + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_b_preshuffle; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_b_preshuffle; + Run(kernel); + } + } + } +#endif + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec 
== GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteA() override { return PermuteA; } + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2BScale +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + 
ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Argument = typename GridwiseGemm::Argument; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * + sizeof(ADataType) / APackedSize; + auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * + sizeof(BDataType) / BPackedSize; + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave + ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && + MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2) + ? 
2 + : 1 + : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + 
TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + } + else + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return 
GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + const BScaleDataType* p_b_scale, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + p_b_scale, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + const void* p_b_scale, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + static_cast(p_b_scale), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversal" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// clang-format off +/** + * \brief WIP: Implements XDL CShuffle V3 GEMM for microscale-compliant data types + * + * This class is a work-in-progress implementation of the XDL CShuffle V3 GEMM for + * microscale-compliant data types. 
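+ * As a concrete illustration of the scale layout described under "Assumptions" below (using the
+ * ScaleBlockSize of 32 enforced by IsValidCompilationParameter): an A tensor of shape [M, K] is
+ * paired with a row-major A-scale tensor of shape [M, K / 32], i.e. one E8M0 scale per 32
+ * consecutive elements along K, and a B tensor of shape [K, N] is paired with a column-major
+ * B-scale tensor of shape [K / 32, N].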
+ * + * Assumptions: + * - A and B data types are compliant with the OCP Microscaling Formats (MX) Specification + * - Each scale applies to ScaleBlockSize elements in K direction + * - A scale matrix is a row-major + * - B scale matrix is a column-major + * - Scale data types must have get_exponent_value() specialization, whereas lowest 8 bits of the + * exponent will be interpreted as conventional biased Float32 exponent (E8M0) + * + * Tunable parameters. + * The CK instance includes a series of tunable template parameters to control the parallel + * granularity of the workload to achieve load balancing on different hardware platforms. These + * parameters include Block Size, M/N/K Per Block, M/N per XDL, AK1, BK1, etc. + * - Block Size determines the number of threads in the thread block. + * - M/N/K Per Block determines the size of tile that each thread block is responsible for + * calculating. + * - M/N Per XDL refers to M/N size for Instinct accelerator Matrix Fused Multiply Add (MFMA) + * instructions operating on a per-wavefront basis. + * - A/B K1 is related to the data type. It can be any value ranging from 1 to K Per Block. To + * achieve the optimal load/store performance, 128bit per load is suggested. In addition, the A/B + * loading parameters must be changed accordingly to match the A/B K1 value; otherwise, it will + * result in compilation errors. + * + * Conditions for achieving computational load balancing on different hardware platforms can vary. + * + * \tparam KPerBlock is the number of elements in K dimension that each block processes (multiply with packed_size_v to get the actual KPerBlock) + * + * Serialized version of the algorithm: + * \code + * // E = A * B + C + * // Loop over E[MPerBlock,NPerBlock] tiles + * for(int mb = 0; mb < M; mb += MPerBlock){ + * for(int nb = 0; nb < N; nb += NPerBlock){ + * // initialize E[MPerBlock,NPerBlock] tile + * for(int mt = mb; mt < mb + MPerBlock; mt++){ + * for(int nt = nb; nt < nb + NPerBlock; nt++){ + * E[mt,nt] = C[mt,nt]; + * } + * } + * + * // multiply-accumulate per tile + * for(int kb = 0; kb < K; kb += KPerBlock){ + * for(int m0 = mb; m0 < mb + MPerBlock; m0 += MWaves * MPerXDL){ + * for(int n0 = nb; n0 < nb + NPerBlock; n0 += NWaves * NPerXDL){ + * for(int mw = m0; mw < m0 + MWaves * MPerXDL; mw += MPerXDL){ + * for(int nw = n0; nw < n0 + NWaves * NPerXDL; nw += NPerXDL){ + * for(int k0 = kb; k0 < kb + KPerBlock; k0 += mfma.num_input_blks*KPack){ + * // MFMA accumulation + * for(int k_pack = k0; k_pack < k0 + mfma.num_input_blks*KPack; k_pack += KPerXdlops){ + * // MFMA instruction + * for(int k_mfma = k_pack; k_mfma < k_pack + KPerXdlops; k_mfma += mfma.k_per_blk){ + * for(int m = mw; m < mw + MPerXDL; m++){ + * for(int n = nw; n < nw + NPerXDL; n++){ + * for(int k = k_mfma; k < k_mfma + mfma.k_per_blk; k++){ + * E[m,n] += A[m,k] * B[k,n]; + * } + * } + * } + * } + * } + * } + * } + * } + * } + * } + * } + * } + * } + * \endcode + * + */ +// clang-format on +template +struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX +{ + // GridwiseGemm + using GridwiseGemm = conditional_t< // + !is_same_v, + GridwiseGemmMX_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + ScaleBlockSize, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + 
ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>, + GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< + ALayout, + BLayout, + CLayout, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + ScaleBlockSize, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + // TODO: Check if this is the right algorithm for minimum_occupancy + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave + ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && + MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2) + ? 2 + : 1 + : 2; + + constexpr auto TailNumChoices = []() { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + return Tuple>{}; + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + return Tuple, constant>{}; + else + static_assert(false, "Unexpected BlkGemmPipelineVer!"); + }(); + constexpr bool Use2LDS = []() { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + return false; + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + return true; + else + static_assert(false, "Unexpected BlkGemmPipelineVer!"); + }(); + const TailNumber tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(K_split); + using BoolChoices = Tuple; + static_for_product>{}( + [&](auto mainloop_choice, auto KBatch_cond_choice, auto tail_num_choice) { + constexpr auto CGlobalMemoryDataOperation = + KBatch_cond_choice.value ? 
InMemoryDataOperationEnum::AtomicAdd + : InMemoryDataOperationEnum::Set; + if(mainloop_choice.value == has_main_k_block_loop && + KBatch_cond_choice.value == (arg.KBatch > 1) && + tail_num_choice.value == tail_num) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_mx< // + Use2LDS, + GridwiseGemm, + mainloop_choice.value, + CGlobalMemoryDataOperation, + minimum_occupancy, + tail_num_choice.value>; + Run(kernel); + } + }); + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + static_assert(is_scale_mfma_data_type() && is_scale_mfma_data_type(), + "Only microscaling formats are supported for ADataType and BDataType"); + + static_assert(ScaleBlockSize == 32, "Only ScaleBlockSize 32 is supported"); + + static_assert(is_same_v && is_same_v, + "ComputeTypeA and ComputeTypeB must be the same as ADataType and BDataType"); + + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(!IsValidCompilationParameter()) + { + return false; + } + + if(ck::get_device_name() != "gfx950") + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const AScaleDataType* p_a_scale, + const BDataType* p_b, + const BScaleDataType* p_b_scale, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideScaleA, + index_t StrideB, + index_t StrideScaleB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_a_scale, + p_b, + p_b_scale, + p_c, + M, + N, + K, + StrideA, + StrideScaleA, + StrideB, + StrideScaleB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideScaleA, + ck::index_t StrideB, + ck::index_t StrideScaleB, + ck::index_t StrideC, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_a_scale), + static_cast(p_b), + static_cast(p_b_scale), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideScaleA, + StrideB, + StrideScaleB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + 
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmMX_Xdl_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Xdl_CShuffleV3R1 : public DeviceGemmV2R1 +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + ReduceDataType, + AElementwiseOperation, + BElementwiseOperation, + PassThrough, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + struct Argument : public GridwiseGemm::Argument + { + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + const std::array p_ds_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t k_batch_) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + reinterpret_cast(p_c_grid_), + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + k_batch_, + true), + p_ds(p_ds_), + StrideDs(StrideDs_) + { + } + + const std::array p_ds; + std::array StrideDs; + }; + + using ReduceAdd = ck::reduce::Add; + using OutElementwiseOperation = 
CElementwiseOperation; + + static constexpr auto DsVectorLengthSequence = generate_sequence_v2( + [](auto i) { + using DLayout = remove_cvref_t>; + if constexpr(std::is_same::value) + return Number{}; + else + return Number<1>{}; + }, + Number{}); + + using DeviceReduceInstance = DeviceReduceThreadWiseMultiD< + ReduceDataType, // InDataType, + DsDataType, // DsDatatype + GemmAccDataType, // AccDataType, + CDataType, // OutDataType, + 3, // Rank + 1, // NumReduceDim + ReduceAdd, + PassThrough, + OutElementwiseOperation, + 256, // BlockSize_, + CShuffleBlockTransferScalarPerVector_NPerBlock, // MThreadSliceSize_, + 1, // KThreadSliceSize_, + 0, // InSrcVectorDim_, + CShuffleBlockTransferScalarPerVector_NPerBlock, // InSrcVectorSize_, + CShuffleBlockTransferScalarPerVector_NPerBlock, // OutDstVectorSize_ + decltype(DsVectorLengthSequence)>; + + // Invoker + struct Invoker : public BaseInvoker + { + float RunReduce(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + static constexpr index_t NumInDim = 3; + static constexpr index_t NumOutDim = 2; + + std::array in_lengths = {arg.KBatch, arg.M, arg.N}; + std::array out_lengths = {arg.M, arg.N}; + + std::array in_strides; + std::array out_strides; + if constexpr(std::is_same::value) + { + in_strides = {arg.M * arg.N, arg.N, 1}; + out_strides = {arg.N, 1}; + } + else + { + in_strides = {arg.M * arg.N, 1, arg.M}; + out_strides = {1, arg.M}; + } + + std::array reduce_dims{0}; + + std::array, NumDTensor> DsLengths; + std::array, NumDTensor> DsStrides; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + DsLengths[i] = out_lengths; + + using DLayout = remove_cvref_t>; + if constexpr(std::is_same::value) + { + DsStrides[i] = {arg.StrideDs[i], 1}; + } + else + { + DsStrides[i] = {1, arg.StrideDs[i]}; + } + }); + + auto reduce = DeviceReduceInstance{}; + + auto argument_ptr = reduce.MakeArgumentPointer(in_lengths, + in_strides, + DsLengths, + DsStrides, + out_lengths, + out_strides, + reduce_dims, + arg.p_workspace_, + arg.p_ds, + arg.p_c_grid, + PassThrough{}, + OutElementwiseOperation{}); + + auto invoker_ptr = reduce.MakeInvokerPointer(); + + float ave_time = 0; + + if(reduce.IsSupportedArgument(argument_ptr.get())) + { + ave_time = invoker_ptr->Run(argument_ptr.get(), stream_config); + } + else + { + throw std::runtime_error( + "The runtime parameters seems not supported by the device instance, exiting!"); + } + + return ave_time; + } + + float Run(const Argument& arg_, const StreamConfig& stream_config = StreamConfig{}) + { + auto arg = *dynamic_cast(&arg_); + + if(!(!(arg.IsReduceAdd() || NumDTensor > 0) && + std::is_same::value)) + { + if(arg.p_workspace_ == nullptr) + { + throw std::runtime_error("using reduce , but empty workspace!"); + } + + arg.p_c_grid = reinterpret_cast(arg.p_workspace_); + } + + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + ck::utility::RotatingMemWrapper rotating_mem( + arg, + stream_config.rotating_count, + arg.M * arg.K * sizeof(ADataType), + arg.K * arg.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg); + } + else + { + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const 
auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + + const auto kernel = kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if(!(!(arg.IsReduceAdd() || NumDTensor > 0) && + std::is_same::value)) + { + // reduce c data + ave_time += RunReduce(arg_, stream_config); + } + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + const std::array p_ds, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) + { + return Argument{p_a, p_b, p_ds, p_c, M, N, K, StrideA, StrideB, StrideDs, StrideC, KBatch}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + KBatch); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmXdlUniversalReduce" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"<(p_arg); + + if(!(!(arg.IsReduceAdd() || NumDTensor > 0) && + 
std::is_same::value)) + { + std::cout << "using workspace" << std::endl; + return arg.M * arg.N * arg.KBatch * sizeof(ReduceDataType); + } + + return 0; + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp new file mode 100755 index 000000000000..213501468a3b --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp @@ -0,0 +1,781 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// The GEMM + Layernorm implementation is a specialized kernel which allows fusing both layers +// together given the condition GEMM extents N of MNK is spanned by a single workgroup. For example, +// a kernel configured with NPerBlock = 128 allows to operate on all GEMM sizes if N <= 128 +// +// Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle +// version currently has compiler issues with register spill which further causes validation +// failures. 
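// [Editor's note] A minimal host-side sketch, not part of the patch, of the reduction step
// used by the split-K path of DeviceGemmXdlUniversalReduce above: partial results are staged
// in a [KBatch, M, N] workspace (hence GetWorkSpaceSize() returning
// arg.M * arg.N * arg.KBatch * sizeof(ReduceDataType)) and summed over the KBatch dimension
// by DeviceReduceThreadWiseMultiD. The loop below mirrors that reduction for a row-major C;
// the function name and the plain std::vector buffers are illustrative assumptions, and the
// output element-wise op is taken as PassThrough.
#include <cstddef>
#include <vector>

template <typename ReduceDataType, typename CDataType>
void reduce_splitk_workspace(const std::vector<ReduceDataType>& workspace, // KBatch * M * N partials
                             std::vector<CDataType>& c,                    // M * N final output
                             std::size_t KBatch, std::size_t M, std::size_t N)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            ReduceDataType acc{0};
            for(std::size_t kb = 0; kb < KBatch; ++kb)
                acc += workspace[kb * M * N + m * N + n]; // in_strides = {M*N, N, 1} for a RowMajor C
            c[m * N + n] = static_cast<CDataType>(acc);   // output op assumed PassThrough
        }
}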
+// +// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta) +template +struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator +{ + using DeviceOp = DeviceGemmLayerNorm_Xdl_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + assert(K % AK1 == 0); + + const auto AK0 = K / AK1; + + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + assert(KRaw % AK1 == 0); + + const auto AK0 = KRaw / AK1; + + const auto a_grid_desc_ak0_m_ak1 = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto 
b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + assert(K % BK1 == 0); + + const auto BK0 = K / BK1; + + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + assert(KRaw % BK1 == 0); + + const auto BK0 = KRaw / BK1; + + const auto b_grid_desc_bk0_n_bk1 = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideC)); + } + }(); + + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + + const auto MPad = M - 
MRaw; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + static auto MakeGridDescriptor_N(index_t NRaw) + { + const auto grid_desc_nraw = make_naive_tensor_descriptor_packed(make_tuple(NRaw)); + + const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock; + const auto NPad = N - NRaw; + + if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad N + return transform_tensor_descriptor(grid_desc_nraw, + make_tuple(make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad N + return grid_desc_nraw; + } + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + using C0GridDesc_N = decltype(MakeGridDescriptor_N(1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + C0DataType, + ReduceAccDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + CGridDesc_M_N, + C0GridDesc_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + 
CReduceThreadClusterLengths_MPerBlock_NPerBlock, + CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + LoopSched>; + + using Block2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const C0DataType* p_c0_grid_add, + const C0DataType* p_c0_grid_bias, + const C0DataType* p_c0_grid_gamma, + const C0DataType* p_c0_grid_beta, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + p_c0_grid_bias_{p_c0_grid_bias}, + p_c0_grid_add_{p_c0_grid_add}, + p_c0_grid_gamma_{p_c0_grid_gamma}, + p_c0_grid_beta_{p_c0_grid_beta}, + a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)}, + b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)}, + c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)}, + c0_grid_desc_n_{MakeGridDescriptor_N(NRaw)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + c0_grid_desc_nblock_nperblock_{}, + block_2_ctile_map_{Block2CTileMap(c_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + acc_element_op_{acc_element_op}, + c_element_op_{c_element_op} + { + if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, + b_grid_desc_bk0_n_bk1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + + c0_grid_desc_nblock_nperblock_ = + GridwiseGemm::MakeC0GridDescriptor_NBlock_NPerBlock(c0_grid_desc_n_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + const C0DataType* p_c0_grid_bias_; + const C0DataType* p_c0_grid_add_; + const C0DataType* p_c0_grid_gamma_; + const C0DataType* p_c0_grid_beta_; + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + C0GridDesc_N c0_grid_desc_n_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock_; + Block2CTileMap block_2_ctile_map_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + AccElementwiseOperation acc_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" 
<< std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + float ave_time = 0; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + const auto kernel = kernel_gemm_layernorm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + C0DataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock, + Block2CTileMap, + true>; + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_bias_, + arg.p_c0_grid_add_, + arg.p_c0_grid_gamma_, + arg.p_c0_grid_beta_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c0_grid_desc_nblock_nperblock_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_layernorm_xdl_cshuffle_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + C0DataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + CElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::C0GridDescriptor_NBlock_NPerBlock, + Block2CTileMap, + false>; + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.p_c0_grid_bias_, + arg.p_c0_grid_add_, + arg.p_c0_grid_gamma_, + arg.p_c0_grid_beta_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.c0_grid_desc_nblock_nperblock_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + const C0DataType* p_c0_bias, + const C0DataType* p_c0_add, + const C0DataType* p_c0_gamma, + const C0DataType* p_c0_beta, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t 
StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + p_c0_bias, + p_c0_add, + p_c0_gamma, + p_c0_beta, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + acc_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + const void* p_c0_bias, + const void* p_c0_add, + const void* p_c0_gamma, + const void* p_c0_beta, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + CElementwiseOperation c_element_op, + index_t /* KBatch */ = 1) + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + static_cast(p_c0_bias), + static_cast(p_c0_add), + static_cast(p_c0_gamma), + static_cast(p_c0_beta), + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideC, + a_element_op, + b_element_op, + acc_element_op, + c_element_op); + } + + std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmLayerNorm_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp new file mode 100755 index 000000000000..7315fe75a3f4 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_skip_b_lds.hpp @@ -0,0 +1,530 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
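// [Editor's note] Host-side reference sketch, not part of the patch, for the epilogue that
// DeviceGemmLayerNorm_Xdl_CShuffle above fuses into the GEMM:
//   D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta)
// with the normalization taken over N per output row, which is why the fusion requires N to be
// spanned by a single workgroup. acc_element_op is taken as PassThrough and `epsilon` is an
// assumption made only for this illustration.
#include <cmath>
#include <vector>

void gemm_layernorm_epilogue_reference(std::vector<float>& acc,         // M x N accumulator of A*B, row-major
                                       const std::vector<float>& bias,  // N, broadcast over rows
                                       const std::vector<float>& add,   // M x N residual
                                       const std::vector<float>& gamma, // N
                                       const std::vector<float>& beta,  // N
                                       int M, int N, float epsilon = 1e-5f)
{
    for(int m = 0; m < M; ++m)
    {
        // bias add + residual add (acc_element_op assumed PassThrough)
        for(int n = 0; n < N; ++n)
            acc[m * N + n] = acc[m * N + n] + bias[n] + add[m * N + n];

        // mean/variance over the N extent of this row
        float mean = 0.f, var = 0.f;
        for(int n = 0; n < N; ++n) mean += acc[m * N + n];
        mean /= N;
        for(int n = 0; n < N; ++n) { const float d = acc[m * N + n] - mean; var += d * d; }
        var /= N;

        // normalize, then scale and shift with the broadcast gamma/beta vectors
        const float inv_std = 1.f / std::sqrt(var + epsilon);
        for(int n = 0; n < N; ++n)
            acc[m * N + n] = (acc[m * N + n] - mean) * inv_std * gamma[n] + beta[n];
    }
}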
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp" + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmXdlSkipBLds : public DeviceGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto K1Number = Number{}; + static_assert(BBlockBufferSize >= 2); + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto 
PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1< + BlockSize, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + InMemoryDataOperationEnum::Set, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferSrcScalarPerVector, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockBufferSize, + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t M01, + index_t N01, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b_grid_{p_b_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{}, + block_2_ctile_map_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + a_grid_desc_k0_m_k1_ = + DeviceGemmXdlSkipBLds::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + b_grid_desc_k0_n_k1_ = + DeviceGemmXdlSkipBLds::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + c_grid_desc_m_n_ = DeviceGemmXdlSkipBLds::MakeCGridDescriptor_M_N(M, N, StrideC); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) + { + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = + GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); + + block_2_ctile_map_ = + GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); + + b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_ = + GridwiseGemm::MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(b_grid_desc_k0_n_k1_); + } + } + + // private: + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + typename 
GridwiseGemm::BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 + b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_; + typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; + typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; + index_t M01_; + index_t N01_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceGemmXdlSkipBLds::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); + + const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; + + if(has_main_k0_block_loop) + { + const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + true>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + else + { + const auto kernel = kernel_gemm_xdlops_skip_b_lds_v1< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + remove_reference_t, + false>; + + ave_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_, + arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.block_2_ctile_map_); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + 
{ + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.M01_, + arg.N01_); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + 1, + 1, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemmXdlSkipBLds" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp new file mode 100755 index 000000000000..2666051c8607 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp @@ -0,0 +1,422 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template + +struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // TODO: should be exposed as Tparams. 
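// [Editor's note] Hedged usage sketch, not part of the patch: the device ops in this patch all
// expose the same MakeArgument / IsSupportedArgument / MakeInvoker / Run surface shown for
// DeviceGemmXdlSkipBLds above. The helper below is parameterized on a concrete instantiation
// `DeviceOp` (template arguments omitted) and assumes the CK headers already included by these
// files; the element-wise ops would typically be ck::tensor_operation::element_wise::PassThrough.
template <typename DeviceOp,
          typename ADataType, typename BDataType, typename CDataType,
          typename AElementwiseOp, typename BElementwiseOp, typename CElementwiseOp>
float run_device_gemm(const ADataType* p_a, const BDataType* p_b, CDataType* p_c,
                      ck::index_t M, ck::index_t N, ck::index_t K,
                      ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC,
                      AElementwiseOp a_op, BElementwiseOp b_op, CElementwiseOp c_op)
{
    auto argument = DeviceOp::MakeArgument(p_a, p_b, p_c, M, N, K,
                                           StrideA, StrideB, StrideC, a_op, b_op, c_op);

    if(!DeviceOp::IsSupportedArgument(argument))
        return -1.f; // this instance cannot handle the given shapes/strides/layouts

    auto invoker = DeviceOp::MakeInvoker();
    return invoker.Run(argument, StreamConfig{}); // returns the invoker's reported ave_time
}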
+ static constexpr index_t NumGemmKPrefetchStage = 1; + + using ComputeTypeA = ComputeType; + using ComputeTypeB = ComputeType; + + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + BlockSize, + ADataType, + BDataType, + AccDataType, + CDataType, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + NumGemmKPrefetchStage, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXDL, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopSched, + PipelineVer, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + + struct Argument : public GridwiseGemm::Argument + { + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0Padded_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + p_c_grid_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + MPadded_, + NPadded_, + KPadded_, + K0Padded_, + k_batch_), + a_element_op(a_element_op_), + b_element_op(b_element_op_), + c_element_op(c_element_op_) + { + } + + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + CElementwiseOperation c_element_op; + }; + + using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + + // Invoker + struct Invoker : public BaseInvoker + { + + void Print(const Argument& karg) { karg.Print(); } + + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + Print(karg); + } + + const auto kbatch = karg.k_batch; + + if(!GridwiseGemm::CheckValidity(karg)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid " + "setting"); + } + + const auto b2c_map = DefaultBlock2CTileMap{}; + index_t gdx, gdy, gdz; + ck::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch); + const auto K0Padded = karg.K0Padded; + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + if(kbatch > 1) + hipGetErrorString(hipMemsetAsync(karg.p_c_grid, + 0, + karg.M * karg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + static_cast(karg), + b2c_map, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + }; + + if(has_main_k0_block_loop) + { + if(kbatch == 1) + { + const auto kernel = + kernel_gemm_xdlops_v2r4r2_simplified; + + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdlops_v2r4r2_simplified; + + Run(kernel); + } + } + else + { + if(kbatch == 1) + { + const auto kernel = + kernel_gemm_xdlops_v2r4r2_simplified; + + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdlops_v2r4r2_simplified; + + Run(kernel); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& karg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(karg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch) + { + return Argument(p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + GridwiseGemm::CalculateMPadded(M), + GridwiseGemm::CalculateNPadded(N), + GridwiseGemm::CalculateKPadded(K, KBatch), + GridwiseGemm::CalculateK0Padded(K, KBatch), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + GridwiseGemm::CalculateMPadded(M), + GridwiseGemm::CalculateNPadded(N), + GridwiseGemm::CalculateKPadded(K, KBatch), + GridwiseGemm::CalculateK0Padded(K, KBatch), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + 
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + str << GridwiseGemm::GetTypeString() << " LoopScheduler: " << LoopSchedToString[LoopSched] + << ", PipelineVersion: " << PipelineVersionToString[PipelineVer]; + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp new file mode 100755 index 000000000000..eda966c48a21 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp @@ -0,0 +1,425 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template + +struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using GridwiseGemm = GridwiseGemm_xdlops_splitk_lds_direct_load< + BlockSize, + ADataType, + BDataType, + AccDataType, + CDataType, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + NumGemmKPrefetchStage, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferScalarPerVector, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferScalarPerVector, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXDL, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopSched, + PipelineVer, + ComputeType>; + + struct Argument : public GridwiseGemm::Argument + { + Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0Padded_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : GridwiseGemm::Argument(p_a_grid_, + p_b_grid_, + p_c_grid_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + MPadded_, + NPadded_, + KPadded_, + K0Padded_, + k_batch_), + a_element_op(a_element_op_), + 
b_element_op(b_element_op_), + c_element_op(c_element_op_) + { + } + + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + CElementwiseOperation c_element_op; + }; + + using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; + + // Invoker + struct Invoker : public BaseInvoker + { + + void Print(const Argument& karg) { karg.Print(); } + + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + Print(karg); + } + + const auto kbatch = karg.k_batch; + + if(!GridwiseGemm::CheckValidity(karg)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid " + "setting"); + } + + const auto b2c_map = DefaultBlock2CTileMap{}; + index_t gdx, gdy, gdz; + ck::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch); + const auto K0Padded = karg.K0Padded; + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0Padded); + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + if(kbatch > 1) + hipGetErrorString(hipMemsetAsync(karg.p_c_grid, + 0, + karg.M * karg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + static_cast(karg), + b2c_map, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); + }; + + if(has_main_k0_block_loop) + { + if(kbatch == 1) + { + const auto kernel = + kernel_gemm_xdlops_splitk_lds_direct_load; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_splitk_lds_direct_load< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; + + Run(kernel); + } + } + else + { + if(kbatch == 1) + { + const auto kernel = + kernel_gemm_xdlops_splitk_lds_direct_load; + + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdlops_splitk_lds_direct_load< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + DefaultBlock2CTileMap, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation>; + + Run(kernel); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& karg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(karg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch) + { + return Argument(p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + GridwiseGemm::CalculateMPadded(M), + GridwiseGemm::CalculateNPadded(N), + GridwiseGemm::CalculateKPadded(K, KBatch), + GridwiseGemm::CalculateK0Padded(K, KBatch), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + static auto MakeInvoker() { 
return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + ck::index_t KBatch = 1) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + GridwiseGemm::CalculateMPadded(M), + GridwiseGemm::CalculateNPadded(N), + GridwiseGemm::CalculateKPadded(K, KBatch), + GridwiseGemm::CalculateK0Padded(K, KBatch), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{ + {PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}, {PipelineVersion::v4, "v4"}}; + + // clang-format off + str << "DeviceGemmXdlSplitKCShuffle_LdsDirectLoad" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferScalarPerVector << ", " + << BBlockTransferScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">" + << " LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer] << ", " + << "Prefetch: " + << NumGemmKPrefetchStage; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp new file mode 100755 index 000000000000..51b8958d61fe --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp @@ -0,0 +1,360 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
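// [Editor's note] Illustrative host-side sketch, not part of the patch, of the split-K
// accumulation used by the two DeviceGemmXdlSplitKCShuffle* invokers above: for k_batch > 1 each
// workgroup computes a partial GEMM over a slice of K and adds it into C via AtomicAdd, which is
// why those invokers hipMemsetAsync C to zero first. Plain += stands in for the atomic here; the
// row-major layouts and the even K slicing are assumptions for the example (the real kernels pad
// K through CalculateKPadded/CalculateK0Padded).
#include <algorithm>
#include <vector>

void splitk_gemm_reference(const std::vector<float>& a, // M x K, row-major
                           const std::vector<float>& b, // K x N, row-major
                           std::vector<float>& c,       // M x N, must start zero-initialized
                           int M, int N, int K, int k_batch)
{
    const int k_per_batch = (K + k_batch - 1) / k_batch;
    for(int kb = 0; kb < k_batch; ++kb)               // each split maps to separate workgroups on the GPU
    {
        const int k_begin = kb * k_per_batch;
        const int k_end   = std::min(K, k_begin + k_per_batch);
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                float partial = 0.f;
                for(int k = k_begin; k < k_end; ++k)
                    partial += a[m * K + k] * b[k * N + n];
                c[m * N + n] += partial;              // AtomicAdd on the device
            }
    }
}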
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmXdlStreamK : public DeviceGemmStreamK +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk< + BlockSize, + BlockToCTileMap_GemmStreamK, + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CDataType, + ALayout, + BLayout, + CLayout, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXDL, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + void Print(const Argument& karg) { karg.Print(); } + + float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + Print(karg); + } + if(!GridwiseGemm::CheckValidity(karg)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid " + "setting"); + } + + dim3 grid_dims = karg.block_mapping.get_grid_dims(); + + float ave_time = 0; + + const auto kernel = kernel_gemm_xdlops_streamk; + + // TODO: remove clear buffer for streamk kernels + if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Atomic) + { + hipGetErrorString(hipMemsetAsync(karg.p_c_grid, + 0, + karg.M * karg.N * sizeof(CDataType), + stream_config.stream_id_)); + ave_time = launch_and_time_kernel(stream_config, + kernel, + grid_dims, + dim3(BlockSize), + 0, + karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + karg.p_workspace_, + karg.M, + karg.N, + karg.K, + karg.StrideA, + karg.StrideB, + karg.StrideC, + karg.block_mapping); + } + else if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + char* workspace_semaphore = reinterpret_cast(karg.p_workspace_) + + karg.block_mapping.get_workspace_size_for_acc( + sizeof(typename GridwiseGemm::FloatAcc)); + auto preprocess = [&]() { + hipGetErrorString( + hipMemsetAsync(workspace_semaphore, + 0, + karg.block_mapping.get_workspace_size_for_semaphore(), + stream_config.stream_id_)); + }; + + ave_time = launch_and_time_kernel_with_preprocess(stream_config, + preprocess, + kernel, + grid_dims, + dim3(BlockSize), + 0, + karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + karg.p_workspace_, + karg.M, + karg.N, + karg.K, + karg.StrideA, + karg.StrideB, + karg.StrideC, + karg.block_mapping); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* p_arg = dynamic_cast(pArg); + if constexpr(GridwiseGemm::Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + return p_arg->block_mapping.get_workspace_size(sizeof(typename GridwiseGemm::FloatAcc)); + } + else + { + return 0; + } + } + + void SetWorkSpacePointer(BaseArgument* pArg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + } + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& karg) + { + if(!(ck::is_xdl_supported())) + { + return false; + } + return GridwiseGemm::CheckValidity(karg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + uint32_t NumSKBlocks = 0xffffffff) + { + const auto kernel = kernel_gemm_xdlops_streamk; + int occupancy, num_cu; + hipError_t rtn; + rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte()); + hip_check_error(rtn); + + hipDeviceProp_t dev_prop; + hipDevice_t dev; + rtn = hipGetDevice(&dev); + hip_check_error(rtn); + rtn = hipGetDeviceProperties(&dev_prop, dev); + hip_check_error(rtn); + num_cu = dev_prop.multiProcessorCount; + + return Argument{p_a, + p_b, + p_c, + M, + N, 
+ K, + StrideA, + StrideB, + StrideC, + static_cast(num_cu), + static_cast(occupancy), + NumSKBlocks}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + index_t NumSKBlocks = 0) override + { + const auto kernel = kernel_gemm_xdlops_streamk; + int occupancy, num_cu; + hipError_t rtn; + rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occupancy, kernel, BlockSize, GridwiseGemm::GetSharedMemoryNumberOfByte()); + hip_check_error(rtn); + + hipDeviceProp_t dev_prop; + hipDevice_t dev; + rtn = hipGetDevice(&dev); + hip_check_error(rtn); + rtn = hipGetDeviceProperties(&dev_prop, dev); + hip_check_error(rtn); + num_cu = dev_prop.multiProcessorCount; + + return std::make_unique(reinterpret_cast(p_a), + reinterpret_cast(p_b), + reinterpret_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + static_cast(num_cu), + static_cast(occupancy), + static_cast(NumSKBlocks)); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override { return GridwiseGemm::GetTypeString(); } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp new file mode 100755 index 000000000000..2554ffea46b0 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp @@ -0,0 +1,526 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
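The stream-K MakeArgument/MakeArgumentPointer overloads above do not derive the grid from M and N; they query kernel occupancy and the CU count, store both in the Argument, and the Invoker later takes its grid from karg.block_mapping.get_grid_dims(). The standalone sketch below reproduces only the device queries, using the same HIP calls as the code above; the final "one full wave" product is a common stream-K heuristic assumed here, not necessarily the mapping GridwiseGemm applies, and dummy_kernel is a placeholder.

#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void dummy_kernel() {} // placeholder for the real stream-K kernel

int main()
{
    int dev = 0, occupancy = 0;
    hipGetDevice(&dev);
    hipDeviceProp_t prop;
    hipGetDeviceProperties(&prop, dev);
    // Resident blocks per CU for 256-thread blocks and no dynamic LDS.
    hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, dummy_kernel, 256, 0);
    const int one_full_wave = prop.multiProcessorCount * occupancy;
    std::printf("CUs=%d occupancy=%d persistent blocks=%d\n",
                prop.multiProcessorCount, occupancy, one_full_wave);
    return 0;
}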
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_WAVELET_MAX_THREAD_PER_BLOCK, CK_WAVELET_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdl_waveletmodel_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const EElementwiseOperation e_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + e_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = e_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm +{ + using DeviceOp = DeviceGemm_Xdl_WaveletModel_CShuffle; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + const auto e_grid_desc_mraw_nraw = 
[&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAcEDataType, + CShuffleDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_M_K, + BGridDesc_N_K, + EGridDesc_M_N, + NumGemmKPrefetchStage, + TileLoadThreadGroupSize, + TileMathThreadGroupSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock>; + + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + + using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const ADataType* p_a_grid, + const BDataType* p_b_grid, + EDataType* p_e_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, + b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, + e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + if(GridwiseGemm::CheckValidity( + a_grid_desc_m_k_, b_grid_desc_n_k_, e_grid_desc_m_n_, block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + } + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + std::cout << "E[M, 
N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { +#if 0 + { + std::cout << "arg.a_grid_desc_ak0_m_ak1_{" + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", " + << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_bk0_n_bk1_{" + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", " + << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", " + << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } +#endif + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.e_grid_desc_m_n_); + const auto K = arg.a_grid_desc_m_k_.GetLength(I1); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_gemm_xdl_waveletmodel_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::DefaultBlock2ETileMap, + has_main_loop>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(TileLoadThreadGroupSize + TileMathThreadGroupSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + // polymorphic + bool IsSupportedArgument(const 
BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + EDataType* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a, + p_b, + p_e, + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideE, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_e, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_e), + MRaw, + NRaw, + KRaw, + StrideA, + StrideB, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGemm_Xdl_WaveletModel_CShuffle" + << "<" + << TileLoadThreadGroupSize << ", " + << TileMathThreadGroupSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp new file mode 100755 index 000000000000..884175eacae4 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,915 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
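All of the device classes added by this diff (split-K, stream-K, wavelet-model, grouped contraction, grouped conv) expose the same polymorphic surface the wavelet-model file above just finished: MakeArgumentPointer, IsSupportedArgument, MakeInvokerPointer, Invoker::Run and GetTypeString. The helper below is an editor's sketch of the client-side call sequence using only those members; run_device_op is a made-up name, the CK headers are assumed to be already included, and StreamConfig is the same struct the Run overloads above default-construct.

#include <iostream>
#include <utility>

template <typename DeviceOp, typename... ArgTs>
float run_device_op(DeviceOp& op, ArgTs&&... args)
{
    // Type-erased argument -> support check -> invoker -> launch.
    auto arg_ptr = op.MakeArgumentPointer(std::forward<ArgTs>(args)...);
    if(!op.IsSupportedArgument(arg_ptr.get()))
    {
        std::cerr << op.GetTypeString() << ": unsupported problem size/layout\n";
        return -1.f;
    }
    auto invoker_ptr = op.MakeInvokerPointer();
    return invoker_ptr->Run(arg_ptr.get(), StreamConfig{}); // averaged kernel time reported by the Invoker
}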
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_contraction_multiple_d_xdl_cshuffle( + const void CK_CONSTANT_ADDRESS_SPACE* contraction_args, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + + const auto contraction_arg_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(contraction_args)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + + while((!(block_id >= contraction_arg_ptr[group_id].block_start_ && + block_id < contraction_arg_ptr[group_id].block_end_)) && + left <= right) + { + if(block_id < contraction_arg_ptr[group_id].block_start_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::template Run( + contraction_arg_ptr[group_id].p_a_grid_, + contraction_arg_ptr[group_id].p_b_grid_, + contraction_arg_ptr[group_id].p_ds_grid_, + contraction_arg_ptr[group_id].p_e_grid_, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + contraction_arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, + contraction_arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, + contraction_arg_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + contraction_arg_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + contraction_arg_ptr[group_id].block_2_etile_map_); +#else + ignore = contraction_args; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[M0, M1, M2, ..., K0, K1, K2, ...] +// B[N0, N1, N2, ..., K0, K1, K2, ...] +// D[M0, M1, M2, ..., N0, N1, N2, ...] +// E[M0, M1, M2, ..., N0, N1, N2, ...] 
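The descriptor builders that follow this comment block (MakeAGridDescriptor_M_K, MakeBGridDescriptor_N_K, MakeEGridDescriptor_M_N) turn the multi-index contraction into an ordinary GEMM by merging dimensions: MRaw = M0*M1*..., NRaw = N0*N1*..., KRaw = K0*K1*.... The snippet below is only an editor's plain-C++ illustration of that flattening for a packed 2+2-dimensional A tensor; the real code builds CK tensor descriptors (and pads them) rather than multiplying sizes by hand.

#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

int main()
{
    // A[M0, M1, K0, K1] with M-dims {8, 16} and K-dims {4, 32}.
    const std::vector<int> m_dims{8, 16}, k_dims{4, 32};
    const int MRaw = std::accumulate(m_dims.begin(), m_dims.end(), 1, std::multiplies<int>());
    const int KRaw = std::accumulate(k_dims.begin(), k_dims.end(), 1, std::multiplies<int>());
    // A[M0,M1,K0,K1] * B[N0,N1,K0,K1] then runs as a (MRaw x KRaw) * (NRaw x KRaw)^T GEMM.
    std::printf("MRaw=%d KRaw=%d\n", MRaw, KRaw); // MRaw=128 KRaw=128
    return 0;
}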
+template +struct DeviceGroupedContractionMultipleD_Xdl_CShuffle + : public DeviceGroupedContractionMultipleD +{ + using DeviceOp = DeviceGroupedContractionMultipleD_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + // Assume: A[M0, M1, M2, ..., K0, K1, K2, ...] + static auto MakeAGridDescriptor_M_K(const std::vector& a_ms_ks_lengths_vec, + const std::vector& a_ms_ks_strides_vec) + { + assert(a_ms_ks_lengths_vec.size() == NumDimM + NumDimK && + a_ms_ks_strides_vec.size() == NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_vec, Number{}); + const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); + + if constexpr(ASpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(M, K), + make_tuple(a_ms_ks_strides[Number{}], + a_ms_ks_strides[Number{}])); + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else + { + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] + const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + } + + // Assume: B[N0, N1, N2, ..., K0, K1, K2, ...] + static auto MakeBGridDescriptor_N_K(const std::vector& b_ns_ks_lengths_vec, + const std::vector& b_ns_ks_strides_vec) + { + assert(b_ns_ks_lengths_vec.size() == NumDimN + NumDimK && + b_ns_ks_strides_vec.size() == NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_vec, Number{}); + const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_vec, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... 
+ const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); + + if constexpr(BSpec == TensorSpecialization::Packed) + { + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor( + make_tuple(N, K), + make_tuple(b_ns_ks_strides[Number{}], + b_ns_ks_strides[Number{}])); + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else + { + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] + const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + } + + // assume E[M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_M_N(const std::vector& e_ms_ns_lengths_vec, + const std::vector& e_ms_ns_strides_vec) + { + assert(e_ms_ns_lengths_vec.size() == NumDimM + NumDimN && + e_ms_ns_strides_vec.size() == NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto num) { + return generate_tuple([&](auto i) { return vec[i]; }, num); + }; + + const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_vec, Number{}); + const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_vec, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); + + if constexpr(DESpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(M, N), + make_tuple(e_ms_ns_strides[Number{}], + e_ms_ns_strides[Number{}])); + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + else + { + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_ms_ns = + make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); + + // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] 
+ const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + } + + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths_vec[i], + ds_ms_ns_strides_vec[i]); + }, + Number{}); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {})); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {})); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + + using ComputeDataType = ADataType; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + struct GroupedContractionBlock2ETileMap + { + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + + GroupedContractionBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, + ck::index_t BlockStart) + { + default_block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + block_start_ = BlockStart; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return default_block_2_etile_map_.CalculateBottomIndex( + make_multi_index(idx_top[I0] - block_start_)); + } + + // it's actually E-Tile + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return default_block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + __host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const + { + return default_block_2_etile_map_.CheckValidity(e_grid_desc_m_n); + } + + Block2ETileMap default_block_2_etile_map_; + ck::index_t block_start_; + }; + + struct ContractionMultiDKernelArg + { + // pointers + const ADataType* p_a_grid_; + const 
BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // lock-to-e-tile map + GroupedContractionBlock2ETileMap block_2_etile_map_; + + ck::index_t block_start_, block_end_; + }; + + struct ContractionMultiDDeviceArg + { + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // Strides for the last M/N/K dimensions of A/B/Ds/E + // for sanity check of vector load/store + index_t a_mz_stride_; + index_t a_kz_stride_; + index_t b_nz_stride_; + index_t b_kz_stride_; + std::array ds_nz_stride_; + // index_t e_mz_stride_; + index_t e_nz_stride_; + }; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::vector p_a_vec, + std::vector p_b_vec, + std::vector> p_ds_vec, + std::vector p_e_vec, + std::vector> contraction_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + group_count_ = contraction_descs.size(); + + if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() && + group_count_ == p_e_vec.size())) + { + throw std::runtime_error("wrong! group_count_ != a/b/e_vec.size"); + } + + contraction_multi_d_kernel_args_.reserve(group_count_); + + grid_size_ = 0; + + for(std::size_t i = 0; i < group_count_; i++) + { + const auto p_a_grid = static_cast(p_a_vec[i]); + const auto p_b_grid = static_cast(p_b_vec[i]); + const auto p_e_grid = static_cast(p_e_vec[i]); + + const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K( + contraction_descs[i].a_ms_ks_lengths, contraction_descs[i].a_ms_ks_strides); + const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K( + contraction_descs[i].b_ns_ks_lengths, contraction_descs[i].b_ns_ks_strides); + + DsGridDesc_M_N ds_grid_desc_m_n; + typename GridwiseGemm::DsGridPointer p_ds_grid; + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid(j) = static_cast(p_ds_vec[i][j]); + + // D desc + ds_grid_desc_m_n(j) = + DeviceOp::MakeEGridDescriptor_M_N(contraction_descs[i].ds_ms_ns_lengths[j], + contraction_descs[i].ds_ms_ns_strides[j]); + }); + + const auto e_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N( + contraction_descs[i].e_ms_ns_lengths, contraction_descs[i].e_ms_ns_strides); + + const auto a_grid_desc_ak0_m_ak1 = + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); + const auto b_grid_desc_bk0_n_bk1 = + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n); + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n); + + const index_t grid_size_grp = + GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n) + .CalculateGridSize(e_grid_desc_m_n); + + 
const index_t BlockStart = grid_size_; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + const auto block_2_etile_map = + GroupedContractionBlock2ETileMap(e_grid_desc_m_n, BlockStart); + + // for sanity check of vector memory access + const index_t a_mz_stride = contraction_descs[i].a_ms_ks_strides[NumDimM - 1]; + const index_t a_kz_stride = + contraction_descs[i].a_ms_ks_strides[NumDimM + NumDimK - 1]; + + const index_t b_nz_stride = contraction_descs[i].b_ns_ks_strides[NumDimN - 1]; + const index_t b_kz_stride = + contraction_descs[i].b_ns_ks_strides[NumDimN + NumDimK - 1]; + + std::array ds_nz_stride; + for(index_t j = 0; j < NumDTensor; ++j) + { + ds_nz_stride[j] = + contraction_descs[i].ds_ms_ns_strides[j][NumDimM + NumDimN - 1]; + } + + const index_t e_nz_stride = + contraction_descs[i].e_ms_ns_strides[NumDimM + NumDimN - 1]; + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map)) + { + contraction_multi_d_kernel_args_.push_back( + {p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map, + BlockStart, + BlockEnd}); + + contraction_multi_d_device_args_.push_back({a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + a_mz_stride, + a_kz_stride, + b_nz_stride, + b_kz_stride, + ds_nz_stride, + e_nz_stride}); + } + } + } + + std::vector contraction_multi_d_kernel_args_; + std::vector contraction_multi_d_device_args_; + + std::size_t group_count_; + index_t grid_size_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + bool has_main_k_block_loop = true; + + for(std::size_t i = 0; i < arg.group_count_; i++) + { + const auto K = + arg.contraction_multi_d_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) * + arg.contraction_multi_d_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop) + { + throw std::runtime_error("wrong! 
not all gemm has_main_k_block_loop"); + } + } + + hipGetErrorString(hipMemcpyWithStream(arg.p_workspace_, + arg.contraction_multi_d_kernel_args_.data(), + arg.contraction_multi_d_kernel_args_.size() * + sizeof(ContractionMultiDKernelArg), + hipMemcpyHostToDevice, + stream_config.stream_id_)); + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = + kernel_grouped_contraction_multiple_d_xdl_cshuffle; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.group_count_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_); + }; + + if(has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + for(std::size_t i = 0; i < arg.group_count_; i++) + { + const auto a_grid_desc_m_k_ = arg.contraction_multi_d_device_args_[i].a_grid_desc_m_k_; + const auto b_grid_desc_n_k_ = arg.contraction_multi_d_device_args_[i].b_grid_desc_n_k_; + const auto ds_grid_desc_m_n_ = + arg.contraction_multi_d_device_args_[i].ds_grid_desc_m_n_; + const auto e_grid_desc_m_n_ = arg.contraction_multi_d_device_args_[i].e_grid_desc_m_n_; + const auto a_grid_desc_ak0_m_ak1_ = + arg.contraction_multi_d_kernel_args_[i].a_grid_desc_ak0_m_ak1_; + const auto b_grid_desc_bk0_n_bk1_ = + arg.contraction_multi_d_kernel_args_[i].b_grid_desc_bk0_n_bk1_; + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + arg.contraction_multi_d_kernel_args_[i] + .ds_grid_desc_mblock_mperblock_nblock_nperblock_; + const auto e_grid_desc_mblock_mperblock_nblock_nperblock_ = + arg.contraction_multi_d_kernel_args_[i] + .e_grid_desc_mblock_mperblock_nblock_nperblock_; + + const auto block_2_etile_map_ = + arg.contraction_multi_d_kernel_args_[i].block_2_etile_map_; + + const auto a_mz_stride_ = arg.contraction_multi_d_device_args_[i].a_mz_stride_; + const auto a_kz_stride_ = arg.contraction_multi_d_device_args_[i].a_kz_stride_; + const auto b_nz_stride_ = arg.contraction_multi_d_device_args_[i].b_nz_stride_; + const auto b_kz_stride_ = arg.contraction_multi_d_device_args_[i].b_kz_stride_; + const auto ds_nz_stride_ = arg.contraction_multi_d_device_args_[i].ds_nz_stride_; + const auto e_nz_stride_ = arg.contraction_multi_d_device_args_[i].e_nz_stride_; + + if(!GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + return false; + } + + // check vector access + static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) && + (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), + "wrong!"); + + // vector memory access of A: could be on M or AK1 dimension + if constexpr(ABlockTransferSrcVectorDim == 1) + { + if(!(a_mz_stride_ == 1 && + a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + if(!(a_kz_stride_ == 1 && + a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + + // vector memory access of B: could be on N or BK1 
dimension + if constexpr(BBlockTransferSrcVectorDim == 1) + { + if(!(b_nz_stride_ == 1 && + b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + if(!(b_kz_stride_ == 1 && + b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + + // vector memory access of Ds: always on NPerBlock dimension + bool valid_d_access = true; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + if(!(ds_nz_stride_[j] == 1 && + ds_grid_desc_mblock_mperblock_nblock_nperblock_[j].GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0)) + { + valid_d_access = false; + } + }); + + if(valid_d_access == false) + { + return false; + } + + // vector memory access of E: always on NPerBlock dimension + if(!(e_nz_stride_ == 1 && e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0)) + { + return false; + } + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector p_a_vec, + std::vector p_b_vec, + std::vector> p_ds_vec, + std::vector p_e_vec, + std::vector> contraction_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{p_a_vec, + p_b_vec, + p_ds_vec, + p_e_vec, + contraction_descs, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector p_a_vec, + std::vector p_b_vec, + std::vector> p_ds_vec, + std::vector p_e_vec, + std::vector> contraction_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique(p_a_vec, + p_b_vec, + p_ds_vec, + p_e_vec, + contraction_descs, + a_element_op, + b_element_op, + cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedContractionMultipleD_Xdl_CShuffle" + << "<" + << NumDimM << ", " + << NumDimN << ", " + << NumDimK << ", " + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << ABlockTransferSrcVectorDim << ", " + << BBlockTransferSrcVectorDim + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + return dynamic_cast(p_arg)->group_count_ * + sizeof(ContractionMultiDKernelArg); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp new file mode 100755 index 000000000000..651e730b63e6 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,833 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. 
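For reference, before the backward-data convolution header starts: the grouped-contraction kernel that ends just above launches one flat grid over every group. The host Argument accumulates per-group grid sizes into [block_start_, block_end_) ranges, and the kernel recovers its group with a binary search over those ranges. The host-side sketch below mirrors that bookkeeping and lookup in plain C++; GroupRange and find_group are illustrative names, and the search body follows the kernel loop shown earlier in this file.

#include <cstdio>
#include <vector>

struct GroupRange { int block_start_, block_end_; };

// Same search the kernel performs on contraction_arg_ptr: find the group whose range holds block_id.
static int find_group(const std::vector<GroupRange>& groups, int block_id)
{
    int left = 0, right = static_cast<int>(groups.size());
    int group_id = (left + right) / 2;
    while(!(block_id >= groups[group_id].block_start_ &&
            block_id < groups[group_id].block_end_) &&
          left <= right)
    {
        if(block_id < groups[group_id].block_start_) { right = group_id; }
        else                                         { left = group_id; }
        group_id = (left + right) / 2;
    }
    return group_id;
}

int main()
{
    // Three groups whose tile grids need 12, 5 and 20 workgroups.
    const std::vector<int> grid_size_grp{12, 5, 20};
    std::vector<GroupRange> groups;
    int grid_size = 0;
    for(int g : grid_size_grp)
    {
        groups.push_back({grid_size, grid_size + g}); // BlockStart / BlockEnd, as in Argument above
        grid_size += g;
    }
    std::printf("block 14 -> group %d (expect 1)\n", find_group(groups, 14));
    return 0;
}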
+ +#pragma once + +#include +#include + +#include "ck/library/utility/numeric.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Conv backward data multiple D: +// input : output image A: [G, N, K, Ho, Wo] +// input : weight B: [G, K, C, Y, X], +// input : D0, D1, ... : [G, N, K, Ho, Wo] +// output : input image E: [G, N, C, Hi, Wi] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle + : public DeviceGroupedConvBwdDataMultipleD +{ + // TODO: Extend support for more spatial dimensions. + static_assert(NDimSpatial == 2 || NDimSpatial == 3, + "wrong! only implemented for 2D and 3D now"); + + using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + // TODO: Add support for different A and B data types. + using ABDataType = ADataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr index_t KPerBlock = K0PerBlock * K1; + + using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1; + + static auto + GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform) + { + const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1(); + const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1(); + const auto ds_grid_desc_m_n = + generate_tuple([&](auto) { return conv_to_gemm_transform.MakeCDescriptor_M_N(); }, + Number{}); + const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N(); + return make_tuple( + a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); + } + + // desc + constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform; + using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform)); + + using AGridDesc_AK0_M_AK1 = remove_cvref_t>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t>; + using DsGridDesc_M_N = remove_cvref_t>; + using EGridDesc_M_N = remove_cvref_t>; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_Wmma< + // DataType Family + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + // InMemory Data Descriptor + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + DsGridDesc_M_N, + EGridDesc_M_N, + // ElementwiseOp Family + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + InMemoryDataOperationEnum::Set, + // Tiling Family + MPerBlock, + NPerBlock, + KPerBlock, + MPerWMMA, + NPerWMMA, + K1, + MRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + 
ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + true, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + true, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + NumGemmKPrefetchStage, + LoopSched, + PipelineVer>; + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{})); + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + EGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, + const std::array& a_g_n_k_wos_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + /*ds_g_n_c_wis_lengths*/, + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, + const std::array& e_g_n_c_wis_lengths, + const std::array& e_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_k_wos_lengths[0]}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + bool image_covered_dilation = true; + bool image_covered_strides = true; + for(index_t d = 0; d < NDimSpatial; d++) + { + // If dilation and stride is not equal to the we will have some empty places + image_covered_dilation &= + conv_filter_dilations[d] == 1 || conv_filter_strides[d] == 1; + // If stride is larger than windows size then we will have some empty places + image_covered_strides &= conv_filter_strides[d] <= b_g_k_c_xs_lengths[d + I3]; + } + bwd_needs_zero_out = k_batch_ > 1 || !image_covered_dilation || !image_covered_strides; + e_space_size_bytes = + ck::accumulate_n( + e_g_n_c_wis_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(EDataType); + + // populate Ds pointer + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid_(i) = static_cast(p_ds[i]); + }); + + // A/B/Ds/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0]; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + 
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0]; + }); + + static constexpr auto NonSpatialDimsNum = Number<3>{}; + + static constexpr auto DIdx = Number{}; + static constexpr auto HIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto WIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto XIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + // problem definition + const index_t Z = b_g_k_c_xs_lengths[ZIdx]; + const index_t Y = b_g_k_c_xs_lengths[YIdx]; + const index_t X = b_g_k_c_xs_lengths[XIdx]; + + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; + + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = + NDimSpatial == 3 ? 
math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1; + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + if(YDotSlice * XDotSlice * ZDotSlice <= 0) + { + continue; + } + + std::array tildes; + if constexpr(NDimSpatial == 2) + { + tildes = {i_ytilde, i_xtilde}; + } + else if constexpr(NDimSpatial == 3) + { + tildes = {i_ztilde, i_ytilde, i_xtilde}; + } + + ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes}; + + const auto a_grid_desc_ak0_m_ak1 = + conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1(); + + const auto b_grid_desc_bk0_n_bk1 = + conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1(); + + DsGridDesc_M_N ds_grid_desc_m_n; + + // populate Ds desc + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + static_assert(is_same_v); + ConvToGemmBwdDataTransform conv_to_gemm_transform_d{ + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + ds_g_n_c_wis_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes}; + + ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N(); + }); + + const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N(); + + // for check validity + ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); + e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); + + // desc for blockwise copy + a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1); + b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1); + + // block-to-e-tile-map + auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap( + e_grid_desc_m_n, 1 /* M01 */, 1 /* N01 */); + + block_2_ctile_map_container_.push_back(block_2_ctile_map); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n)); + e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n)); + } + } + } + } + + void Print() const + { + for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + std::cout << "a_grid_desc_ak0_m_ak1_container_" + << a_grid_desc_ak0_m_ak1_container_[i] << std::endl; + + std::cout << "b_grid_desc_bk0_n_bk1_container_" + << b_grid_desc_bk0_n_bk1_container_[i] << std::endl; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + std::cout << "ds_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i][j] + << std::endl; + }); + + std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i] + << std::endl; + } + } + + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptor for problem definition + index_t num_group_; + std::vector ds_grid_desc_m_n_container_; + std::vector e_grid_desc_m_n_container_; + + // tensor descriptor for block-wise copy + std::vector a_grid_desc_ak0_m_ak1_container_; + std::vector 
b_grid_desc_bk0_n_bk1_container_; + std::vector + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_; + std::vector + e_grid_desc_mblock_mperblock_nblock_nperblock_container_; + + // block-to-e-tile map + std::vector block_2_ctile_map_container_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOp a_element_op_; + BElementwiseOp b_element_op_; + CDEElementwiseOp cde_element_op_; + + std::array a_g_n_k_wos_lengths_; + std::array b_g_k_c_xs_lengths_; + std::array conv_filter_strides_; + std::array input_left_pads_; + std::array input_right_pads_; + + const index_t k_batch_; + bool bwd_needs_zero_out; + long_index_t e_space_size_bytes; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + float ave_time = 0; + + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize( + arg.e_grid_desc_m_n_container_[i]) * + arg.num_group_; + + const auto GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * + arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); + + const auto clear_workspace = [&]() { + if(arg.bwd_needs_zero_out && i == 0) + { + hip_check_error(hipMemsetAsync( + arg.p_e_grid_, 0, arg.e_space_size_bytes, stream_config.stream_id_)); + } + }; + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle< + GridwiseGemm, + ADataType, + BDataType, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_k_wos_lengths_[0], // Group count + arg.a_grid_desc_ak0_m_ak1_container_[i], + arg.b_grid_desc_bk0_n_bk1_container_[i], + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], + arg.block_2_ctile_map_container_[i], + arg.compute_ptr_offset_of_batch_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK)) + { + ave_time += launch_kernel(integral_constant{}); + } + else + { + ave_time += launch_kernel(integral_constant{}); + } + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(arg.k_batch_ != 1) + { + return false; + } + + // check device + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; + const 
index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; + + // Specialization + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's a 1x1 convolution with stride=1 and no padding + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.b_g_k_c_xs_lengths_[3 + i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load for A matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // vector load for B matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v) + { + if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // vector store for Ds + bool ds_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + // vector load D matrix from global memory + if(!(ConvC % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + ds_valid = false; + } + } + else + { + ds_valid = false; + } + }); + + if(!ds_valid) + { + return false; + } + + // vector store for E + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + // vector store C matrix into global memory + if(!(ConvC % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // Gridwise GEMM size + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_container_[i], + arg.b_grid_desc_bk0_n_bk1_container_[i], + arg.ds_grid_desc_m_n_container_[i], + arg.e_grid_desc_m_n_container_[i], + arg.block_2_ctile_map_container_[i])) + { + return false; + } + } + + // check number of dimension, only implemented for 2D and 3D now + if(NDimSpatial != 2 && NDimSpatial != 3) + { + return false; + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + 
e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", " + << K1 << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp new file mode 100755 index 000000000000..db2426518a1a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -0,0 +1,1830 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
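// Editor's note — illustrative usage sketch, not part of the original sources: the
// WMMA operator defined above and the XDL operator defined in this header expose the
// same BaseOperator interface (MakeArgumentPointer -> IsSupportedArgument ->
// MakeInvokerPointer -> Run). Assuming `DeviceOpInstance` is a concrete instantiation
// of one of them and the length/stride arrays are already filled in, a typical
// host-side call sequence looks like:
//
//   auto op  = DeviceOpInstance{};
//   auto arg = op.MakeArgumentPointer(p_out, p_wei, p_ds, p_in,
//                                     a_lengths, a_strides,
//                                     b_lengths, b_strides,
//                                     ds_lengths, ds_strides,
//                                     e_lengths, e_strides,
//                                     conv_strides, conv_dilations,
//                                     left_pads, right_pads,
//                                     a_elementwise_op, b_elementwise_op,
//                                     cde_elementwise_op, /*split_k=*/1);
//   if(!op.IsSupportedArgument(arg.get()))
//   {
//       throw std::runtime_error("argument not supported by this instance");
//   }
//   auto invoker = op.MakeInvokerPointer();
//   float ms     = invoker->Run(arg.get(), StreamConfig{});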
+ +#pragma once + +#include +#include + +#include "ck/library/utility/numeric.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/host_utility/io.hpp" + +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link + * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of + * pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * MaxGroupedGemmGroupsNum is used to specify number of gemm args in compile time. With this + * implementation we can avoid copy data to workspace before kernel launch since number of groups is + * runtime parameter. If number of groups is larger than MaxGroupedGemmGroupsNum then we run this + * kernel in the loop. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). 
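 *
 * Worked example (added for illustration, not from the original comment): for a 2D
 * backward-data convolution with conv_filter_strides = {2, 2} and
 * conv_filter_dilations = {1, 1}, gcd(stride, dilation) = 1 in each dimension, so
 * YTilde = XTilde = 2 and the problem splits into YTilde * XTilde = 4 sub-GEMMs.
 * With MaxGroupedGemmGroupsNum = 32 (the non-1x1 specialization) all four argument
 * sets fit into the single std::array passed by value to one kernel launch; only if
 * the number of sub-GEMMs exceeded MaxGroupedGemmGroupsNum would the invoker launch
 * the kernel once per set of at most MaxGroupedGemmGroupsNum argument sets.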
+ * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const std::array gemm_kernel_args, + const index_t gemms_count, + const AElementwiseOp a_element_op, + const BElementwiseOp b_element_op, + const CDEElementwiseOp cde_element_op, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const ComputePtrOffsetOfN compute_ptr_offset_of_n, + const index_t KBatch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // offset base pointer for each work-group + const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch); + + const long_index_t a_batch_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + const long_index_t a_n_offset = + CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t b_n_offset = + CTranspose ? 
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0; + + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = DsPointer::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + index_t left = 0; + index_t right = gemms_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ && + block_args_id < gemm_kernel_args[group_id].BlockEnd_)) && + left <= right) + { + if(block_args_id < gemm_kernel_args[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + if constexpr(HasMainKBlockLoopInAllGemm || NoMainKBlockLoopInAllGemm) + { + GridwiseGemm::template Run( + p_a_grid + a_batch_offset + a_n_offset, + p_b_grid + b_batch_offset + b_n_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].block_2_ctile_map_, + KBatch, + k_idx); + } + else + { + if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + { + GridwiseGemm::template Run( + p_a_grid + a_batch_offset + a_n_offset, + p_b_grid + b_batch_offset + b_n_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].block_2_ctile_map_, + KBatch, + k_idx); + } + else + { + GridwiseGemm::template Run( + p_a_grid + a_batch_offset + a_n_offset, + p_b_grid + b_batch_offset + b_n_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_kernel_args[group_id].block_2_ctile_map_, + KBatch, + k_idx); + } + } +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = gemm_kernel_args; + ignore = gemms_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = compute_ptr_offset_of_n; +#endif +} + +} // namespace + +// Conv backward data multiple D: +// input : output image A: [G, N, K, Ho, Wo] +// input : weight B: [G, K, C, Y, X], +// input : D0, D1, ... : [G, N, K, Ho, Wo] +// output : input image E: [G, N, C, Hi, Wi] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 + : public DeviceGroupedConvBwdDataMultipleD +{ + // TODO: Extend support for more spatial dimensions. 
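    // Editor's note — illustrative sketch, not part of the original header: the
    // grouped kernel above resolves which sub-GEMM owns the current workgroup by
    // binary-searching the half-open block ranges [BlockStart_, BlockEnd_) stored in
    // gemm_kernel_args. A host-side equivalent of that lookup, assuming a simple
    // `Range{start, end}` aggregate, would be:
    //
    //   struct Range { int start, end; };
    //
    //   int find_group(const std::vector<Range>& ranges, int block_id)
    //   {
    //       int left  = 0;
    //       int right = static_cast<int>(ranges.size());
    //       int group = (left + right) / 2;
    //       while(!(block_id >= ranges[group].start && block_id < ranges[group].end) &&
    //             left <= right)
    //       {
    //           if(block_id < ranges[group].start) { right = group; }
    //           else                               { left  = group; }
    //           group = (left + right) / 2;
    //       }
    //       return group; // index of the range containing block_id
    //   }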
+ static_assert(NDimSpatial == 2 || NDimSpatial == 3, + "wrong! only implemented for 2D and 3D now"); + + // MaxGroupedGemmGroupsNum is used to specify number of gemm args in compile time. With this + // implementation we can avoid copy data to workspace before kernel launch since number of + // groups is runtime parameter. If number of groups is larger than MaxGroupedGemmGroupsNum then + // we run this kernel in the loop. + static constexpr index_t MaxGroupedGemmGroupsNum = + ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0 + ? 1 + : 32; + + using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1; + + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr GemmSpecialization GemmSpec = GemmSpecialization::MNKPadding; + static constexpr bool IsSplitKSupported = + (CDEBlockTransferScalarPerVector_NPerBlock % 2 == 0 || sizeof(EDataType) % 4 == 0) && + std::is_same_v, element_wise::PassThrough>; + + // TODO: Add support for different A and B data types. + using ABDataType = ADataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr bool isATensorColMajor = + (ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) && + (ABlockTransferSrcVectorDim == 1) && + (is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()); + + static constexpr bool NeedTransposeKernel = + (isATensorColMajor == false) && (is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()); + + static constexpr bool CTranspose = + (NeedTransposeKernel == false) && (is_same_v || + is_same_v); + + using ALayoutAfterTranspose = std::conditional_t< + is_NGCHW_NGKHW() && NeedTransposeKernel, + tensor_layout::convolution::NHWGK, + std::conditional_t() && NeedTransposeKernel, + tensor_layout::convolution::NDHWGK, + ALayout>>; + using BLayoutAfterTranspose = std::conditional_t< + is_NGCHW_GKCYX_NGKHW() && NeedTransposeKernel, + tensor_layout::convolution::GKYXC, + std::conditional_t() && + NeedTransposeKernel, + tensor_layout::convolution::GKZYXC, + BLayout>>; + using ELayoutAfterTranspose = std::conditional_t< + is_NGCHW_NGKHW() && NeedTransposeKernel, + tensor_layout::convolution::NHWGC, + std::conditional_t() && NeedTransposeKernel, + tensor_layout::convolution::NDHWGC, + ELayout>>; + + using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1; + + static auto + GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform) + { + const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1(); + + const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1(); + + const auto ds_grid_desc_m_n = generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + using ConvToGemmBwdDataTransformD = + TransformConvBwdDataToGemm_v1; + return ConvToGemmBwdDataTransformD{}.MakeCDescriptor_M_N(); + }, + Number{}); + + const auto e_grid_desc_m_n = conv_to_gemm_transform.MakeCDescriptor_M_N(); + if constexpr(CTranspose) + { + return make_tuple( + b_grid_desc_bk0_n_bk1, a_grid_desc_ak0_m_ak1, ds_grid_desc_m_n, e_grid_desc_m_n); + } + else + { + return make_tuple( + a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n); + } + } + +// GridwiseGemm +#define GridwiseGemmMultiDTemplateParams \ + ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, 
EDataType, \ + AElementwiseOp, BElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ + MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType + +#define GridwiseGemmCTransposeTemplateParameters \ + ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ + BElementwiseOp, AElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ + NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, MPerXDL, NXdlPerWave, MXdlPerWave, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType + + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle; + using GridwiseGemmCTranspose = std::conditional_t< + CTranspose, + GridwiseGemmMultipleD_xdl_cshuffle, + GridwiseGemm>; + + template + static auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N e_grid_desc_m_n) + { + return GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n); + } + + template + static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) + { + const auto grid_desc_m_k = transform_tensor_descriptor( + desc_k0_m_k1, + make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)), + make_merge_transform( + make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_m_k; + } + + // desc + constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform; + using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform)); + + using AGridDesc_AK0_M_AK1 = remove_cvref_t>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t>; + using DsGridDesc_M_N = remove_cvref_t>; + using EGridDesc_M_N = remove_cvref_t>; + + using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{})); + using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + 
DsGridDesc_M_N{})); + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})); + + // block-to-e-tile map + using Block2ETileMap = + decltype(GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(EGridDesc_M_N{})); + + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; + + struct GemmArgs + { + GemmArgs() = default; + GemmArgs(AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + GroupedGemmBlock2ETileMap block_2_ctile_map, + index_t BlockStart, + index_t BlockEnd, + bool HasMainKBlockLoop) + : a_grid_desc_ak0_m_ak1_(a_grid_desc_ak0_m_ak1), + b_grid_desc_bk0_n_bk1_(b_grid_desc_bk0_n_bk1), + + ds_grid_desc_mblock_mperblock_nblock_nperblock_( + ds_grid_desc_mblock_mperblock_nblock_nperblock), + + e_grid_desc_mblock_mperblock_nblock_nperblock_( + e_grid_desc_mblock_mperblock_nblock_nperblock), + + // block-to-e-tile map + block_2_ctile_map_(block_2_ctile_map), + BlockStart_(BlockStart), + BlockEnd_(BlockEnd), + HasMainKBlockLoop_(HasMainKBlockLoop) + + { + } + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + GroupedGemmBlock2ETileMap block_2_ctile_map_; + index_t BlockStart_, BlockEnd_; + bool HasMainKBlockLoop_; + }; + using Block2TileMapInOutElementwise = BlockToCTileMap_M00_N0_M01Adapt; + using Block2TileMapWeiElementwise = BlockToCTileMap_M00_N0_M01Adapt; + + static constexpr index_t ClusterLengthMPerBlock = + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + static constexpr auto conv_ngchw_to_nhwgc_transformer = + TransformConvNGCHWToNHWGC{}; + + static constexpr index_t TransposeTransferInScalarPerVectorAligned = + std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferInScalarPerVector); + static constexpr index_t TransposeTransferOutScalarPerVectorAligned = + std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferOutScalarPerVector); + + using NGCHWTransposeDescType = + remove_cvref_t({}, {}))>; + using NHWGCTransposeDescType = + remove_cvref_t({}, {}))>; + using GKCYXTransposeDescType = + remove_cvref_t({}, {}))>; + using GKYXCTransposeDescType = + remove_cvref_t({}, {}))>; + + static constexpr index_t ElementwiseBlocksize = ClusterLengthMPerBlock * ClusterLengthNPerBlock; + + using GridwiseElementwiseInputTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapInOutElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + MPerBlock, + NPerBlock / ClusterLengthNPerBlock, + MPerBlock / ClusterLengthMPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I1, + I0>; + + using GridwiseElementwiseWeightTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapWeiElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 
0>, + Sequence<1>, + Sequence, + I0, + I1>; + + using GridwiseElementwiseOutputTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapInOutElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + MPerBlock, + NPerBlock / ClusterLengthNPerBlock, + MPerBlock / ClusterLengthMPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I0, + I1>; + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, + const std::array& a_g_n_k_wos_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, + const std::array& e_g_n_c_wis_lengths, + const std::array& e_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + ck::index_t split_k = 1) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_k_wos_lengths[0]}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + bool image_covered_dilation = true; + bool image_covered_strides = true; + for(index_t d = 0; d < NDimSpatial; d++) + { + // If dilation and stride is not equal to the we will have some empty places + image_covered_dilation &= + conv_filter_dilations[d] == 1 || conv_filter_strides[d] == 1; + // If stride is larger than windows size then we will have some empty places + image_covered_strides &= conv_filter_strides[d] <= b_g_k_c_xs_lengths[d + I3]; + } + bool if_d_is_output_mem = false; + const void* out_mem_void = static_cast(p_e); + static_for<0, NumDTensor, 1>{}([&](auto i) { + if(p_ds[i] == out_mem_void) + { + if_d_is_output_mem = true; + } + }); + + bwd_needs_zero_out = k_batch_ > 1 || !image_covered_dilation || !image_covered_strides; + + // Temporary workaround untill prove/fix above conditions. + bwd_needs_zero_out = !if_d_is_output_mem; + e_space_size_bytes = + ck::accumulate_n( + e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(EDataType); + + std::array a_g_n_k_wos_strides_transposed = + NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides) + : a_g_n_k_wos_strides; + std::array b_g_k_c_xs_strides_transposed = + NeedTransposeKernel ? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides) + : b_g_k_c_xs_strides; + std::array e_g_n_c_wis_strides_transposed = + NeedTransposeKernel ? 
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + e_g_n_c_wis_lengths, e_g_n_c_wis_strides) + : e_g_n_c_wis_strides; + + // populate Ds pointer + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + p_ds_grid_(i) = static_cast(p_ds[i]); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0]; + }); + + static constexpr auto NonSpatialDimsNum = Number<3>{}; + + static constexpr auto DIdx = Number{}; + static constexpr auto HIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto WIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto XIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + // problem definition + const index_t Z = b_g_k_c_xs_lengths[ZIdx]; + const index_t Y = b_g_k_c_xs_lengths[YIdx]; + const index_t X = b_g_k_c_xs_lengths[XIdx]; + + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; + + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + index_t grid_size = 0; + // Allocate place for sets of gemms + gemm_kernel_args_.resize( + math::integer_divide_ceil(ZTilde * YTilde * XTilde, MaxGroupedGemmGroupsNum)); + + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = + NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1; + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + if(YDotSlice * XDotSlice * ZDotSlice <= 0) + { + continue; + } + + std::array tildes; + if constexpr(NDimSpatial == 2) + { + tildes = {i_ytilde, i_xtilde}; + } + else if constexpr(NDimSpatial == 3) + { + tildes = {i_ztilde, i_ytilde, i_xtilde}; + } + else + { + throw std::runtime_error("wrong! 
only implemented for 2D and 3D now"); + } + + ConvToGemmBwdDataTransform conv_to_gemm_transform_{ + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides_transposed, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides_transposed, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides_transposed, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes, + k_batch_}; + + conv_N_per_block_ = conv_to_gemm_transform_.N_; + + const auto a_grid_desc_ak0_m_ak1 = [&]() { + if constexpr(CTranspose) + { + return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1(); + } + else + { + return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1(); + } + }(); + + const auto b_grid_desc_bk0_n_bk1 = [&]() { + if constexpr(CTranspose) + { + return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1(); + } + else + { + return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1(); + } + }(); + DsGridDesc_M_N ds_grid_desc_m_n; + + // populate Ds desc + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + using ConvToGemmBwdDataTransformD = + TransformConvBwdDataToGemm_v1; + ConvToGemmBwdDataTransformD conv_to_gemm_transform_d{ + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides_transposed, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides_transposed, + ds_g_n_c_wis_lengths[i], + ds_g_n_c_wis_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes}; + + ds_grid_desc_m_n(i) = conv_to_gemm_transform_d.MakeCDescriptor_M_N(); + }); + + const auto e_grid_desc_m_n = conv_to_gemm_transform_.MakeCDescriptor_M_N(); + + // desc for problem definition + const auto a_grid_desc_m_k = + transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1); + const auto b_grid_desc_n_k = + transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1); + + a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k); + b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k); + ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); + e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); + + const index_t grid_size_grp = Block2ETileMap::CalculateGridSize( + e_grid_desc_m_n.GetLength(I0), e_grid_desc_m_n.GetLength(I1)); + + const index_t BlockStart = grid_size; + const index_t BlockEnd = grid_size + grid_size_grp; + + grid_size += grid_size_grp; + + // block-to-e-tile map + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(Block2ETileMap(e_grid_desc_m_n.GetLength(I0), + e_grid_desc_m_n.GetLength(I1)), + BlockStart); + + const auto GemmK = a_grid_desc_m_k.GetLength(I1); + const bool HasMainKBlockLoop = + GridwiseGemmCTranspose::CalculateHasMainKBlockLoop(GemmK, k_batch_); + + gemm_kernel_args_[gemms_count_ / + MaxGroupedGemmGroupsNum][gemms_count_ % + MaxGroupedGemmGroupsNum] = + GemmArgs{a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + GridwiseGemmCTranspose:: + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n), + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n), + block_2_etile_map, + BlockStart, + BlockEnd, + HasMainKBlockLoop}; + gemms_count_++; + if(gemms_count_ % MaxGroupedGemmGroupsNum == 0) + { + gemms_grid_size_.push_back(grid_size); + grid_size = 0; + } + } + } + } + gemm_kernel_args_.resize( + math::integer_divide_ceil(gemms_count_, MaxGroupedGemmGroupsNum)); + gemms_grid_size_.push_back(grid_size); + + // A/B/Ds/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides_transposed[0]; + 
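            // Editor's note — illustrative helper, not original code: the
            // ZTilde/YTilde/XTilde factors computed above follow the usual
            // backward-data decomposition stride / gcd(stride, dilation). A
            // standalone one-dimensional version of that computation:
            //
            //   #include <numeric> // std::gcd
            //
            //   int compute_tilde(int conv_stride, int conv_dilation)
            //   {
            //       // e.g. stride = 4, dilation = 2 -> gcd = 2 -> tilde = 2,
            //       // i.e. the filter taps of this dimension split into two
            //       // interleaved sub-GEMMs.
            //       return conv_stride / std::gcd(conv_stride, conv_dilation);
            //   }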
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides_transposed[0]; + + compute_ptr_offset_of_n_.BatchStrideA_ = + a_g_n_k_wos_strides_transposed[1] * conv_N_per_block_; + compute_ptr_offset_of_n_.BatchStrideE_ = + e_g_n_c_wis_strides_transposed[1] * conv_N_per_block_; + + num_workgroups_per_Conv_N_ = a_g_n_k_wos_lengths_[I1] / conv_N_per_block_; + + if constexpr(NeedTransposeKernel) + { + // Use not modified base strides + a_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides, num_workgroups_per_Conv_N_); + a_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides, num_workgroups_per_Conv_N_); + + b_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides); + b_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides); + + e_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_); + e_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + e_g_n_c_wis_lengths, e_g_n_c_wis_strides, num_workgroups_per_Conv_N_); + + elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapInOutElementwise{ + a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapWeiElementwise{ + b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; + elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapInOutElementwise{ + e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)}; + + compute_ptr_offset_of_workspace_n_.BatchStrideA_ = + a_g_n_k_wos_strides[1] * conv_N_per_block_; + compute_ptr_offset_of_workspace_n_.BatchStrideE_ = + e_g_n_c_wis_strides[1] * conv_N_per_block_; + } + } + + std::size_t GetWorkspaceATensorSizeBytes() const + { + if constexpr(NeedTransposeKernel) + { + const long_index_t a_acum = ck::accumulate_n( + a_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + // Align to 128B + return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceBTensorSizeBytes() const + { + if constexpr(NeedTransposeKernel) + { + const long_index_t b_acum = ck::accumulate_n( + b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + // Align to 128B + return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + if constexpr(NeedTransposeKernel) + { + const long_index_t e_accum = ck::accumulate_n( + e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + return sizeof(EDataType) * e_accum; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() + + GetWorkspaceETensorSizeBytes(); + } + + void Print() const + { + for(std::size_t i = 0; i < a_grid_desc_m_k_container_.size(); i++) + { + std::cout << "a_grid_desc_m_ak_container_" << a_grid_desc_m_k_container_[i] + << std::endl; + + std::cout << "b_grid_desc_n_bk_container_" << b_grid_desc_n_k_container_[i] + << std::endl; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + std::cout << 
"ds_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << ds_grid_desc_m_n_container_[i][j] << std::endl; + }); + + std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << e_grid_desc_m_n_container_[i] << std::endl; + } + } + + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptor for problem definition + index_t num_group_; + index_t conv_N_per_block_; + std::vector a_grid_desc_m_k_container_; + std::vector b_grid_desc_n_k_container_; + std::vector ds_grid_desc_m_n_container_; + std::vector e_grid_desc_m_n_container_; + + // block-to-e-tile map + Block2TileMapInOutElementwise elementwise_block_2_ctile_map_transpose_a_, + elementwise_block_2_ctile_map_transpose_e_; + Block2TileMapWeiElementwise elementwise_block_2_ctile_map_transpose_b_; + + NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_; + NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_; + GKCYXTransposeDescType b_in_transpose_desc_; + GKYXCTransposeDescType b_out_transpose_desc_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_workspace_n_; + + // element-wise op + AElementwiseOp a_element_op_; + BElementwiseOp b_element_op_; + CDEElementwiseOp cde_element_op_; + + std::array a_g_n_k_wos_lengths_; + std::array b_g_k_c_xs_lengths_; + std::array e_g_n_c_wis_lengths_; + std::array conv_filter_strides_; + std::array input_left_pads_; + std::array input_right_pads_; + + const index_t k_batch_; + index_t num_workgroups_per_Conv_N_; + std::vector gemms_grid_size_; + index_t gemms_count_ = 0; + std::vector> gemm_kernel_args_; + + bool bwd_needs_zero_out; + long_index_t e_space_size_bytes; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + template + float RunMultiDGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + const index_t gdy = arg.num_group_; + const index_t gdz = arg.num_workgroups_per_Conv_N_ * arg.k_batch_; + + const ADataType* p_a_grid = arg.p_a_grid_; + const BDataType* p_b_grid = arg.p_b_grid_; + EDataType* p_e_grid = arg.p_e_grid_; + if constexpr(NeedTransposeKernel) + { + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_); + p_e_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + } + + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + p_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + } + } + for(std::size_t gemm_set_id = 0; gemm_set_id < arg.gemm_kernel_args_.size(); + gemm_set_id++) + { + const index_t gdx = arg.gemms_grid_size_[gemm_set_id]; + const index_t gemms_count_for_set = + gemm_set_id == arg.gemm_kernel_args_.size() - 1 + ? 
arg.gemms_count_ - MaxGroupedGemmGroupsNum * gemm_set_id + : MaxGroupedGemmGroupsNum; + const std::array& gemm_kernel_args = + arg.gemm_kernel_args_[gemm_set_id]; + + const auto clear_workspace = [&]() { + if(arg.bwd_needs_zero_out && gemm_set_id == 0) + { + hip_check_error(hipMemsetAsync( + p_e_grid, 0, arg.e_space_size_bytes, stream_config.stream_id_)); + } + }; + + bool has_loop_in_all_gemm = true; + bool no_loop_in_all_gemm = true; + for(auto i = 0; i < gemms_count_for_set; i++) + { + has_loop_in_all_gemm &= gemm_kernel_args[i].HasMainKBlockLoop_; + no_loop_in_all_gemm &= !gemm_kernel_args[i].HasMainKBlockLoop_; + } + + auto launch_kernel = [&](auto has_main_k_block_loop, auto no_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr bool no_main_loop = no_main_k_block_loop.value; + if constexpr(CTranspose) + { + const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle< + GridwiseGemmCTranspose, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + MaxGroupedGemmGroupsNum, + GemmArgs, + BElementwiseOp, + AElementwiseOp, + CDEElementwiseOp, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + ElementOp, + has_main_loop, + no_main_loop, + CTranspose>; + + return launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + p_b_grid, + p_a_grid, + arg.p_ds_grid_, + p_e_grid, + gemm_kernel_args, + gemms_count_for_set, + arg.b_element_op_, + arg.a_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.compute_ptr_offset_of_n_, + arg.k_batch_); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_data_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + MaxGroupedGemmGroupsNum, + GemmArgs, + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + ElementOp, + has_main_loop, + no_main_loop, + CTranspose>; + + return launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + arg.p_ds_grid_, + p_e_grid, + gemm_kernel_args, + gemms_count_for_set, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_batch_, + arg.compute_ptr_offset_of_n_, + arg.k_batch_); + } + }; + if(has_loop_in_all_gemm) + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(no_loop_in_all_gemm) + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + return ave_time; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + // Transpose from NGKHW to NHWGK + if constexpr(NeedTransposeKernel) + { + EDataType* p_e_in_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + + const auto clear_workspace = [&]() { + hip_check_error(hipMemsetAsync(p_e_in_grid, + 0, + arg.GetWorkspaceETensorSizeBytes(), + stream_config.stream_id_)); + }; + + const index_t a_grid_size = + arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( + arg.a_in_transpose_desc_) 
* + arg.num_workgroups_per_Conv_N_; + const index_t b_grid_size = + (is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + ? arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( + arg.b_in_transpose_desc_) + : 0; // Dont run transpose B if not needed + + ADataType* p_a_out_grid = type_convert(arg.p_workspace_); + BDataType* p_b_out_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + + auto kernel_transpose = + kernel_elementwise_batched_dual, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapInOutElementwise, + Block2TileMapWeiElementwise, + element_wise::PassThrough, + I1, + I1, + I1, + I1>; + + ave_time += launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel_transpose, + dim3(a_grid_size + b_grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.b_in_transpose_desc_), + make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.b_out_transpose_desc_), + make_tuple(arg.p_a_grid_), + make_tuple(arg.p_b_grid_), + make_tuple(p_a_out_grid), + make_tuple(p_b_out_grid), + arg.elementwise_block_2_ctile_map_transpose_a_, + arg.elementwise_block_2_ctile_map_transpose_b_, + element_wise::PassThrough{}, + a_grid_size, + arg.num_workgroups_per_Conv_N_, + I1, // B is not splited per N + std::array{ + static_cast(arg.compute_ptr_offset_of_workspace_n_.BatchStrideA_)}, + std::array{0}, + std::array{ + static_cast(arg.compute_ptr_offset_of_n_.BatchStrideA_)}, + std::array{0}); + } + if(arg.k_batch_ > 1) + { + if constexpr(IsSplitKSupported) + { + ave_time += + RunMultiDGemm(arg, stream_config); + } + } + else + { + ave_time += RunMultiDGemm(arg, stream_config); + } + + // Transpose from NHWGC to NGCHW + if constexpr(NeedTransposeKernel) + { + const index_t grid_size = + arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( + arg.e_in_transpose_desc_) * + arg.num_workgroups_per_Conv_N_; + + const EDataType* p_e_in_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + + EDataType* p_e_out_grid = arg.p_e_grid_; + + auto kernel_transpose = + kernel_batched_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapInOutElementwise, + element_wise::PassThrough, + I1, + I1>; + + ave_time += launch_and_time_kernel( + stream_config, + kernel_transpose, + dim3(grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.e_in_transpose_desc_), + make_tuple(arg.e_out_transpose_desc_), + make_tuple(p_e_in_grid), + make_tuple(p_e_out_grid), + arg.elementwise_block_2_ctile_map_transpose_e_, + element_wise::PassThrough{}, + arg.num_workgroups_per_Conv_N_, + std::array{ + static_cast(arg.compute_ptr_offset_of_n_.BatchStrideE_)}, + std::array{static_cast( + arg.compute_ptr_offset_of_workspace_n_.BatchStrideE_)}); + } + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && + arg.k_batch_ > 1) + { + return false; + } + + if constexpr(!IsSplitKSupported) + { + if(arg.k_batch_ != 1) + { + return false; + } + } + + const index_t ConvG = arg.b_g_k_c_xs_lengths_[0]; + const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; + 
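        // Editor's note — clarifying example, not original code: with split-K enabled
        // the GemmK dimension is partitioned across k_batch_ partial GEMMs that
        // accumulate into the same E tensor, which is why the checks above require
        // atomic support for the output data type and why the invoker zero-fills E
        // before the first launch (bwd_needs_zero_out is set whenever k_batch_ > 1).
        // Ignoring tile-size rounding, with GemmK = 1024 and k_batch_ = 4 each partial
        // GEMM reduces a 256-wide K slice:
        //
        //   int k_per_batch(int gemm_k, int k_batch)
        //   {
        //       // ceil-divide so the last slice absorbs any remainder
        //       return (gemm_k + k_batch - 1) / k_batch; // 1024, 4 -> 256
        //   }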
const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; + const index_t output_spatial_acum = ck::accumulate_n( + arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + const index_t input_spatial_acum = ck::accumulate_n( + arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + // Specifialization + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.b_g_k_c_xs_lengths_[3 + i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load for A matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v || NeedTransposeKernel) + { + if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else if(is_same_v || + is_same_v) + { + static_assert(NeedTransposeKernel == false); + + if constexpr(ABlockTransferSrcScalarPerVector != 1) + { + if(ABlockTransferSrcVectorDim != 1) + { + return false; + } + if(output_spatial_acum % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + } + else + { + return false; + } + + // vector load for B matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // vector store for Ds + bool ds_valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + if(CTranspose == false) + { + // vector load D matrix from global memory + if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + ds_valid = false; + } + } + else + { + if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + ds_valid = false; + } + } + } + else + { + ds_valid = false; + } + }); + + if(!ds_valid) + { + return false; + } + + // vector store for E + if constexpr(is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v) + { + if(CTranspose == false) + { + // vector store C matrix into global memory + if(!(ConvC % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + } + else + { + return false; + } + + // Gridwise GEMM size + for(std::size_t i = 0; i < arg.a_grid_desc_m_k_container_.size(); i++) + { + if(!GridwiseGemmCTranspose::CheckValidity( + arg.a_grid_desc_m_k_container_[i], + arg.b_grid_desc_n_k_container_[i], + arg.ds_grid_desc_m_n_container_[i], + arg.e_grid_desc_m_n_container_[i], + arg.gemm_kernel_args_[i / MaxGroupedGemmGroupsNum][i % MaxGroupedGemmGroupsNum] + .block_2_ctile_map_, + arg.k_batch_)) + { + return false; + } + } + + if constexpr(NeedTransposeKernel) + { + if((ConvG * ConvC) % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + + if((ConvG * ConvK) % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + + const index_t a_spatial_acum = ck::accumulate_n( + arg.a_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + const index_t e_spatial_acum = 
ck::accumulate_n( + arg.e_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + + if(a_spatial_acum % TransposeTransferInScalarPerVectorAligned != 0) + { + return false; + } + + if(e_spatial_acum % TransposeTransferOutScalarPerVectorAligned != 0) + { + return false; + } + + if(!arg.p_workspace_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout + << "Warning: Workspace for " + "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument is not " + "allocated, use SetWorkSpacePointer." + << std::endl; + } + return false; + } + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { 
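        // Editor's note — illustrative usage sketch, not original code: for the
        // NGCHW/NGCDHW layouts this operator transposes A/B/E through scratch memory,
        // and IsSupportedArgument above rejects arguments whose workspace pointer was
        // never set. Assuming `op` is an instance of this device op and `arg` came
        // from MakeArgumentPointer, the expected host-side flow is:
        //
        //   const std::size_t ws_bytes = op.GetWorkSpaceSize(arg.get());
        //   void* p_ws = nullptr;
        //   hip_check_error(hipMalloc(&p_ws, ws_bytes));
        //   op.SetWorkSpacePointer(arg.get(), p_ws);
        //   // ... run the invoker, then hipFree(p_ws) when done.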
+ auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle; + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { + str << ", TransposeTransferInScalarPerVectorAligned: " + << TransposeTransferInScalarPerVectorAligned <<", " + << "TransposeTransferOutScalarPerVectorAligned: " << TransposeTransferOutScalarPerVectorAligned; + } + + + str << ">"; + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp new file mode 100755 index 000000000000..0b3f1a025567 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp @@ -0,0 +1,1234 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
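// Editor's note — illustrative example, not from the original sources: GetTypeString()
// above serializes an instance's tuning parameters in the order they are streamed
// (BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, the bwd-data specialization
// string, ...). A hypothetical instance with BlockSize = 256,
// MPerBlock = NPerBlock = 128, KPerBlock = 32 and AK1 = BK1 = 8 would therefore report
// a string beginning with
//
//   "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<256, 128, 128, 32, 8, 8, ..."
//
// which is the form that instance-selection and logging code typically matches on.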
+ +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_dlops_bwd_weight( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const index_t batch_count, + const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1, + const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + + __shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)]; + + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc_kbatch_k0_m0_m1_k1, + b_grid_desc_kbatch_k0_n0_n1_k1, + c_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_ctile_map, + integral_constant{}, + integral_constant{}); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = batch_count; + ignore = a_grid_desc_kbatch_k0_m0_m1_k1; + ignore = b_grid_desc_kbatch_k0_n0_n1_k1; + ignore = c_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = block_2_ctile_map; + ignore = compute_ptr_offset_of_batch; +#endif +} + +template +struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight +{ + using DeviceOp = DeviceGroupedConvBwdWeight_Dl; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static 
constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto spatial_offset = I3; + + static constexpr auto K1Number = Number{}; + static constexpr auto GemmK1Number = K1Number; + + // Bytes per 32 lds bank: 32 * 4 bytes + static constexpr auto BankLength = 128; + static constexpr auto ElePerBank = BankLength / sizeof(ADataType); + + // M1 & M0 + static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1; + static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; + static constexpr auto ABlockLdsM1Padding = 4; + + // N1 & N0 + static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1; + static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; + static constexpr auto BBlockLdsN1Padding = 4; + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t batch_k) + { + using namespace ck; + + const index_t N = a_g_n_c_wis_lengths[I1]; + const index_t K = b_g_k_c_xs_lengths[I1]; + const index_t C = a_g_n_c_wis_lengths[I2]; + const index_t Wi = a_g_n_c_wis_lengths[spatial_offset]; + const index_t Wo = e_g_n_k_wos_lengths[spatial_offset]; + const index_t X = b_g_k_c_xs_lengths[spatial_offset]; + const index_t InLeftPadW = input_left_pads[I0]; + const index_t InRightPadW = input_right_pads[I0]; + const index_t ConvStrideW = conv_filter_strides[I0]; + const index_t ConvDilationW = conv_filter_dilations[I0]; + + const auto InNStride = a_g_n_c_wis_strides[I1]; + const auto InCStride = a_g_n_c_wis_strides[I2]; + const auto InWStride = a_g_n_c_wis_strides[spatial_offset]; + const auto WeiKStride = b_g_k_c_xs_strides[I1]; + const auto WeiCStride = b_g_k_c_xs_strides[I2]; + const auto OutKStride = e_g_n_k_wos_strides[I2]; + const auto OutWStride = e_g_n_k_wos_strides[spatial_offset]; + + const index_t GemmKTotal = N * Wo; + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride)); + + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Wi, C), make_tuple(InWStride, InCStride)); + + const auto in_gemmkpad_gemmnpad_grid_desc = + 
ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weights tensor + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmmpad_gemmnpad_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Wo, K), make_tuple(OutWStride, OutKStride)); + const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor( + make_tuple(N, Wi, C), make_tuple(InNStride, InWStride, InCStride)); + + // A: output tensor + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(X, C)), + make_merge_transform(make_tuple(N, Wo))), + make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 
3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmmpad_gemmnpad_grid_desc); + } + + } // function end + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t batch_k) + { + using namespace ck; + + const index_t N = a_g_n_c_wis_lengths[I1]; + const index_t K = b_g_k_c_xs_lengths[I1]; + const index_t C = a_g_n_c_wis_lengths[I2]; + const index_t Hi = a_g_n_c_wis_lengths[spatial_offset]; + const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I1]; + const index_t Ho = e_g_n_k_wos_lengths[spatial_offset]; + const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I1]; + const index_t Y = b_g_k_c_xs_lengths[spatial_offset]; + const index_t X = b_g_k_c_xs_lengths[spatial_offset + I1]; + + const index_t InLeftPadH = input_left_pads[I0]; + const index_t InLeftPadW = input_left_pads[I1]; + const index_t InRightPadH = input_right_pads[I0]; + const index_t InRightPadW = input_right_pads[I1]; + const index_t ConvStrideH = conv_filter_strides[I0]; + const index_t ConvStrideW = conv_filter_strides[I1]; + const index_t ConvDilationH = conv_filter_dilations[I0]; + const index_t ConvDilationW = conv_filter_dilations[I1]; + + const auto InNStride = a_g_n_c_wis_strides[I1]; + const auto InCStride = a_g_n_c_wis_strides[I2]; + const auto InHStride = a_g_n_c_wis_strides[spatial_offset]; + const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I1]; + const auto WeiKStride = b_g_k_c_xs_strides[I1]; + const auto WeiCStride = b_g_k_c_xs_strides[I2]; + const auto OutKStride = e_g_n_k_wos_strides[I2]; + const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I1]; + + const index_t GemmKTotal = N * Ho * Wo; + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); + + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + 
make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Hi * Wi, C), make_tuple(InWStride, InCStride)); + + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmmpad_gemmnpad_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); + const auto in_n_hi_wi_c_grid_desc = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(InNStride, InHStride, InWStride, InCStride)); + + // A: output tensor + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmnpad_grid_desc = + 
ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmmpad_gemmnpad_grid_desc); + } + + } // function end + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t batch_k) + { + using namespace ck; + + const index_t N = a_g_n_c_wis_lengths[I1]; + const index_t K = b_g_k_c_xs_lengths[I1]; + const index_t C = a_g_n_c_wis_lengths[I2]; + const index_t Di = a_g_n_c_wis_lengths[spatial_offset + I0]; + const index_t Hi = a_g_n_c_wis_lengths[spatial_offset + I1]; + const index_t Wi = a_g_n_c_wis_lengths[spatial_offset + I2]; + const index_t Do = e_g_n_k_wos_lengths[spatial_offset + I0]; + const index_t Ho = e_g_n_k_wos_lengths[spatial_offset + I1]; + const index_t Wo = e_g_n_k_wos_lengths[spatial_offset + I2]; + const index_t Z = b_g_k_c_xs_lengths[spatial_offset + I0]; + const index_t Y = b_g_k_c_xs_lengths[spatial_offset + I1]; + const index_t X = b_g_k_c_xs_lengths[spatial_offset + I2]; + + const index_t InLeftPadD = input_left_pads[I0]; + const index_t InLeftPadH = input_left_pads[I1]; + const index_t InLeftPadW = input_left_pads[I2]; + const index_t InRightPadD = input_right_pads[I0]; + const index_t InRightPadH = input_right_pads[I1]; + const index_t InRightPadW = input_right_pads[I2]; + const index_t ConvStrideD = conv_filter_strides[I0]; + const index_t ConvStrideH = conv_filter_strides[I1]; + const index_t ConvStrideW = conv_filter_strides[I2]; + const index_t ConvDilationD = conv_filter_dilations[I0]; + const index_t ConvDilationH = conv_filter_dilations[I1]; + const index_t ConvDilationW = conv_filter_dilations[I2]; + + const auto InNStride = a_g_n_c_wis_strides[I1]; + const auto InCStride = a_g_n_c_wis_strides[I2]; + const auto InDStride = a_g_n_c_wis_strides[spatial_offset]; + const auto InHStride = a_g_n_c_wis_strides[spatial_offset + I1]; + const auto InWStride = a_g_n_c_wis_strides[spatial_offset + I2]; + const auto WeiKStride = b_g_k_c_xs_strides[I1]; + const auto WeiCStride = b_g_k_c_xs_strides[I2]; + const auto OutKStride = e_g_n_k_wos_strides[I2]; + const auto OutWStride = e_g_n_k_wos_strides[spatial_offset + I2]; + + const index_t GemmKTotal = N * Do * Ho * Wo; 
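// The reduction dimension GemmKTotal = N * Do * Ho * Wo is split below into
// (GemmKBatch, GemmK0, GemmK1): GemmK0 is rounded up so that
// GemmKBatch * GemmK0 * GemmK1Number covers GemmKTotal in whole K0PerBlock tiles, and
// the A/B descriptors are padded to that extent. Illustrative numbers (not from this
// diff): GemmKTotal = 1000, K1 = 8, K0PerBlock = 4, GemmKBatch = 1 gives
// GemmK0 = ceil(1000 / 32) * 4 = 128, so the padded K extent is 1 * 128 * 8 = 1024,
// i.e. 24 padded elements along K.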
+ const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); + + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Di * Hi * Wi, C), make_tuple(InWStride, InCStride)); + + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmmpad_gemmnpad_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = make_naive_tensor_descriptor( + make_tuple(N * Do * Ho * Wo, K), make_tuple(OutWStride, OutKStride)); + const auto in_n_di_hi_wi_c_grid_desc = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(InNStride, InDStride, InHStride, InWStride, InCStride)); + + // A: output tensor + const auto out_gemmkpad_gemmmpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, MPerBlock), + Sequence{}); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmmpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(out_gemmkpad_gemmmpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + 
make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(GemmK1Number * K0PerBlock * GemmKBatch, NPerBlock), + Sequence{}); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmnpad_grid_desc, + make_tuple( + make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(in_gemmkpad_gemmnpad_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, Z * Y * X * C), make_tuple(WeiKStride, WeiCStride)); + + const auto wei_gemmmpad_gemmnpad_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(MPerBlock, NPerBlock), + Sequence{}); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmmpad_gemmnpad_grid_desc); + } + + } // function end + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {1}, 1); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, 1); + } + + template ::type = false> + static auto GetABCGridDesc() + { + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>({1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + 1); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_B_K0_M_K1 = remove_cvref_t; + using BGridDesc_B_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using GridwiseGemm = + GridwiseGemmDl_bkm_bkn_mn_v1r3; + + // Argument + using AGridDesc_B_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_B_K0_M0_M1_K1(AGridDesc_B_K0_M_K1{})); + using BGridDesc_B_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_B_K0_N0_N1_K1(BGridDesc_B_K0_N_K1{})); + using 
CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using Block2CTileMap = + decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + block_2_ctile_map_{}, + compute_ptr_offset_of_batch_{}, + a_element_op_{out_element_op}, + b_element_op_{wei_element_op}, + c_element_op_{in_element_op}, + Conv_G_{a_g_n_c_wis_lengths[I0]}, + Conv_K_{b_g_k_c_xs_lengths[I1]}, + Conv_C_{a_g_n_c_wis_lengths[I2]}, + filter_lengths_{b_g_k_c_xs_lengths}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + a_grid_desc_kbatch_k0_m0_m1_k1_ = + GridwiseGemm::MakeAGridDescriptor_B_K0_M0_M1_K1(a_grid_desc_kbatch_k0_m_k1_); + b_grid_desc_kbatch_k0_n0_n1_k1_ = + GridwiseGemm::MakeBGridDescriptor_B_K0_N0_N1_K1(b_grid_desc_kbatch_k0_n_k1_); + c_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_); + ck::index_t M01 = 1; + ck::index_t N01 = 1; + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[I0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[I0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = b_g_k_c_xs_strides[I0]; + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + + AGridDesc_B_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_B_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1_; + BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1_; + CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_; + + // DefaultBlock2CTileMap block_2_ctile_map_; + Block2CTileMap block_2_ctile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; + + // element-wise op + OutElementwiseOperation a_element_op_; + WeiElementwiseOperation b_element_op_; + InElementwiseOperation 
c_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_K_; + const index_t Conv_C_; + + std::array filter_lengths_; + const std::array& conv_filter_strides_; + const std::array& conv_filter_dilations_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + + ShowInfo(arg); + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm GridwiseGemmDl_bkm_bkn_mn_v1r3 has invalid setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_; + + auto launch_kernel = [&](auto has_main_k_block_loop, + auto has_double_tail_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr bool has_double_loop = has_double_tail_k_block_loop.value; + + const auto kernel = kernel_batched_gemm_dlops_bwd_weight< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + CDataType, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch<>, + has_main_loop, + has_double_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.Conv_G_, + arg.a_grid_desc_kbatch_k0_m0_m1_k1_, + arg.b_grid_desc_kbatch_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + const auto K0 = arg.a_grid_desc_kbatch_k0_m0_m1_k1_.GetLength(I1); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool 
IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + + // DL version only supports split_k equal to 1 + if(arg.k_batch_ != 1) + return false; + + if constexpr(!((NDimSpatial == 1 && + (is_NWGC_GKXC_NWGK() || + is_GNWC_GKXC_GNWK())) || + (NDimSpatial == 2 && + (is_NHWGC_GKYXC_NHWGK() || + is_GNHWC_GKYXC_GNHWK())) || + (NDimSpatial == 3 && + (is_NDHWGC_GKZYXC_NDHWGK() || + is_GNDHWC_GKZYXC_GNDHWK())))) + { + return false; + } + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_lengths_[spatial_offset + i] == 1 && + arg.conv_filter_strides_[i] == 1 && arg.input_left_pads_[i] == 0 && + arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // matrix A + { + auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{}; + if(srcVectorLengths[I2] != 1 || srcVectorLengths[I3] != 1) + { + return false; + } + if(K1 % srcVectorLengths[I4] != 0 || K0PerBlock % srcVectorLengths[I1] != 0) + { + return false; + } + + const index_t K = arg.Conv_K_; + + if(K % (srcVectorLengths[I1] * srcVectorLengths[I4]) != 0) + { + return false; + } + } + + // matrix B + { + auto srcLoadLenghts = BBlockTransferThreadSliceLengths_K0_N0_N1_K1{}; + auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{}; + if(srcVectorLengths[I1] != 1 || srcVectorLengths[I4] != 1) + { + return false; + } + if(srcLoadLenghts[I2] % srcVectorLengths[I2] != 0 || + srcLoadLenghts[I3] % srcVectorLengths[I3] != 0) + { + return false; + } + + const index_t C = arg.Conv_K_; + + if(C % (srcVectorLengths[I2] * srcVectorLengths[I3]) != 0) + { + return false; + } + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0)) + { + std::cout << "Not surpport,because: arg.Conv_C_ % CThreadTransferDstScalarPerVector = " + << arg.Conv_C_ % CThreadTransferDstScalarPerVector << std::endl; + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m_n_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return 
Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdWeight_Dl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << K1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp new file mode 100755 index 000000000000..a819b91b053c --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp @@ -0,0 +1,486 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
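// The operator in this header delegates the actual GEMM to a prebuilt batched GEMM device
// op (DeviceGemmV3Op) and only adds the convolution-to-GEMM argument mapping; it is
// restricted to 1x1, stride-1, pad-0 filters (see IsSupportedArgument below). When the
// weight element is narrower than 4 bytes and the wrapped GEMM's first C-shuffle vector
// width is odd, a two-stage path is used: the GEMM writes into an intermediate workspace
// and a separate element-wise kernel casts it into WeiDataType. The condition, restated
// from the code below:
//
//   constexpr bool IsTwoStageNeeded =
//       sizeof(WeiDataType) % 4 != 0 &&
//       DeviceGemmV3Op::CDEShuffleBlockTransferScalarPerVectors_::At(I0) % 2 != 0;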
+ +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include + +namespace ck { +namespace tensor_operation { +namespace device { + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceGroupedConvBwdWeight_Explicit_Xdl + : public DeviceGroupedConvBwdWeight +{ + static_assert(is_same_v); + static_assert(is_same_v); + static_assert(is_same_v); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr bool IsTwoStageNeeded = + sizeof(WeiDataType) % 4 != 0 && + DeviceGemmV3Op::CDEShuffleBlockTransferScalarPerVectors_::At(I0) % 2 != 0; + + using DeviceOp = DeviceGroupedConvBwdWeight_Explicit_Xdl; + using TwoStageIntermediateType = typename DeviceGemmV3Op::CDataType_; + + static constexpr index_t ElementwiseBlockSize = 256; + static constexpr index_t ElemsPerBlock = 256; + + static auto GetElementwiseCGridDesc(index_t merged_filter_dims) + { + const auto padd_size = merged_filter_dims % ElemsPerBlock == 0 + ? 0 + : ElemsPerBlock - merged_filter_dims % ElemsPerBlock; + const auto desc = make_naive_tensor_descriptor_packed(make_tuple(I1, merged_filter_dims)); + return transform_tensor_descriptor( + desc, + make_tuple(make_pass_through_transform(I1), + make_right_pad_transform(merged_filter_dims, padd_size)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + using CElementwiseGridDesc = remove_cvref_t; + using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<1, ElemsPerBlock>; + using GridwiseElementwiseCast = GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + WeiElementwiseOperation, + ElementwiseBlockSize, + I1, + ElemsPerBlock, + I1, + ElemsPerBlock / ElementwiseBlockSize, + Sequence<0, 1>, + Sequence<1>, + Sequence<1>, + I1, + I1>; + + struct Argument : public BaseArgument + { + using GemmArgument = typename DeviceGemmV3Op::Argument; + + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array&, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array&, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : filter_spatial_lengths_{}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + p_wei_grid_{p_wei_grid} + { + constexpr index_t spatial_offset = 3; + const index_t DoHoWo = std::accumulate(begin(a_g_n_k_wos_lengths) + spatial_offset, + end(a_g_n_k_wos_lengths), + index_t{1}, + std::multiplies<>{}); + const index_t M = e_g_k_c_xs_lengths[I1]; + const index_t N = e_g_k_c_xs_lengths[I2]; + const index_t K = a_g_n_k_wos_lengths[I1] * DoHoWo; + + const index_t StrideOut = a_g_n_k_wos_strides[spatial_offset + NDimSpatial - 1]; + const index_t StrideIn = b_g_n_c_wis_strides[spatial_offset + NDimSpatial - 1]; + 
const index_t StrideWei = e_g_k_c_xs_strides[I1]; + const index_t StrideBatchOut = a_g_n_k_wos_strides[I0]; + const index_t StrideBatchIn = b_g_n_c_wis_strides[I0]; + const index_t StrideBatchWei = e_g_k_c_xs_strides[I0]; + + const index_t BatchSize = a_g_n_k_wos_lengths[I0]; + + std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset, + end(e_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + + if constexpr(IsTwoStageNeeded) + { + const index_t merged_filter_dims = std::accumulate(begin(e_g_k_c_xs_lengths), + end(e_g_k_c_xs_lengths), + index_t{1}, + std::multiplies<>{}); + elementwise_desc_ = GetElementwiseCGridDesc(merged_filter_dims); + elementwise_block_2_ctile_map_ = Block2TileMapElementwise{1, merged_filter_dims}; + // Check if stride to last dimension is product of all other dimensions. Then it is + // packed. + is_filter_data_packed = + e_g_k_c_xs_strides[0] == (merged_filter_dims / e_g_k_c_xs_lengths[0]); + + // Data type is modified during launch. It is checked in IsSupported if user + // allocated workspace + explicit_gemm_args = GemmArgument{p_out_grid, + p_in_grid, + {}, + static_cast(nullptr), + M, + N, + K, + StrideOut, + StrideIn, + {}, + StrideWei, + StrideBatchOut, + StrideBatchIn, + {}, + StrideBatchWei, + BatchSize, + out_element_op, + in_element_op, + wei_element_op, + split_k}; + } + else + { + explicit_gemm_args = GemmArgument{p_out_grid, + p_in_grid, + {}, + p_wei_grid, + M, + N, + K, + StrideOut, + StrideIn, + {}, + StrideWei, + StrideBatchOut, + StrideBatchIn, + {}, + StrideBatchWei, + BatchSize, + out_element_op, + in_element_op, + wei_element_op, + split_k}; + } + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + if constexpr(IsTwoStageNeeded) + { + return sizeof(TwoStageIntermediateType) * elementwise_desc_.GetElementSpaceSize(); + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + if constexpr(IsTwoStageNeeded) + { + return GetWorkspaceETensorSizeBytes(); + } + else + { + return 0; + } + } + + GemmArgument explicit_gemm_args; + std::array filter_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + WeiDataType* p_wei_grid_; + bool is_filter_data_packed; + CElementwiseGridDesc elementwise_desc_; + Block2TileMapElementwise elementwise_block_2_ctile_map_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + using GemmArgument = typename DeviceGemmV3Op::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if constexpr(IsTwoStageNeeded) + { + // Modify to use workspace as output + GemmArgument explicit_gemm_args_with_workspace = arg.explicit_gemm_args; + explicit_gemm_args_with_workspace.p_c_grid = + static_cast(arg.p_workspace_); + float avg_time = + explicit_gemm_op.Run(explicit_gemm_args_with_workspace, stream_config); + const index_t grid_size = + arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.elementwise_desc_); + const auto kernel = kernel_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + WeiElementwiseOperation>; + + avg_time += launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(ElementwiseBlockSize), + 0, + make_tuple(arg.elementwise_desc_), + make_tuple(arg.elementwise_desc_), + make_tuple(static_cast(arg.p_workspace_)), + make_tuple(arg.p_wei_grid_), + arg.elementwise_block_2_ctile_map_, + element_wise::PassThrough{}); + return avg_time; + } + else + { + return 
explicit_gemm_op.Run(arg.explicit_gemm_args, stream_config); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + + typename DeviceGemmV3Op::Invoker explicit_gemm_op; + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(NDimSpatial == 2) + { + if constexpr(!is_NHWGC_GKYXC_NHWGK()) + { + return false; + } + } + else if constexpr(NDimSpatial == 3) + { + if constexpr(!is_NDHWGC_GKZYXC_NDHWGK()) + { + return false; + } + } + else + { + return false; + } + + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + if constexpr(IsTwoStageNeeded) + { + if(!arg.is_filter_data_packed) + { + return false; + } + // Check this here, it allows to use other instances from factory even + // if workspace is not allocated + if(!arg.p_workspace_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Warning: Workspace for " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not " + "allocated, use SetWorkSpacePointer." + << std::endl; + } + return false; + } + } + // Gridwise GEMM size + return DeviceGemmV3Op::IsSupportedArgument(arg.explicit_gemm_args); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) override + { + return 
std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdWeight_Explicit_Xdl" + << "<" << DeviceGemmV3Op{}.GetTypeString() << ">"; + // clang-format on + + return str.str(); + } + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp new file mode 100755 index 000000000000..672c7dd2f701 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,1103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
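// The final header in this diff adds the multiple-D (extra D tensors, e.g. fused
// bias/scale) XDL C-shuffle backward-weight operator. Its gridwise GEMM is instantiated
// with InMemoryDataOperationEnum::AtomicAdd, so split-K work-groups accumulate partial
// results directly into an AccDataType buffer, with the workspace in/out vector width
// capped at MaxScalarPerVectorFP32 = 4 in the fp32 case (WorkspaceInOutScalarPerVector
// below). A typical host-side consequence of atomic accumulation (general CK usage,
// assumed here rather than shown in this diff) is that the accumulation buffer is
// cleared before launch, e.g.:
//
//   hip_check_error(hipMemsetAsync(p_workspace, 0, workspace_size_bytes, stream));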
+ +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight_multiple_d.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" +#include +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_xdlops_bwd_weight( + const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const index_t batch_count, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t c_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + + __shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = batch_count; + ignore = block_2_ctile_map; + ignore = compute_ptr_offset_of_batch; + + compute_ptr_offset_of_batch.GetAPtrOffset(0); + compute_ptr_offset_of_batch.GetBPtrOffset(0); + compute_ptr_offset_of_batch.GetCPtrOffset(0); +#endif // end of if (defined(__gfx9__)) +} + +template +struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle + : public DeviceGroupedConvBwdWeightMultipleD +{ + using DeviceOp = DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle; + + using ADataType = OutDataType; + using 
BDataType = InDataType; + using EDataType = WeiDataType; + + static constexpr index_t NumDTensor = DsLayout::Size(); + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CDEElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvBwdWeightToGemm{}; + + static constexpr index_t MaxScalarPerVectorFP32 = 4; + static constexpr index_t WorkspaceInOutScalarPerVector = + is_same_v + ? math::min(CBlockTransferScalarPerVector_NWaveNPerXdl, MaxScalarPerVectorFP32) + : CBlockTransferScalarPerVector_NWaveNPerXdl; + + // Bytes per 32 lds bank: 32 * 4 bytes + static constexpr auto BankLength = 128; + static constexpr auto ElePerBank = BankLength / sizeof(ADataType); + + // M1 & M0 + static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1; + static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; + static constexpr auto ABlockLdsM1Padding = 4; + + // N1 & N0 + static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1; + static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; + static constexpr auto BBlockLdsN1Padding = 4; + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1}; + const std::array strides{1, 1, 1, 1}; + const std::array params{1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, + BDataType, + AccDataType, + AccDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + element_wise::PassThrough, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + 
K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + WorkspaceInOutScalarPerVector, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true, + 1, + PipelineVersion::v1, + ComputeTypeA, + ComputeTypeB>; + + static constexpr auto MakeElementwiseInputSequence() + { + return generate_sequence_v2( + [&](auto) constexpr { return Number{}; }, + Number{}); + } + + static constexpr auto GetDsGridPointerTuple() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + return static_cast(nullptr); + }, + Number{}); + } + + template ::type = false> + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides) + { + return generate_tuple( + [&](auto i) { + const index_t K = ds_g_k_c_xs_lengths[i][I1]; + const index_t C = ds_g_k_c_xs_lengths[i][I2]; + const index_t X = ds_g_k_c_xs_lengths[i][I3]; + const index_t CStride = ds_g_k_c_xs_strides[I2]; + const index_t KStride = ds_g_k_c_xs_strides[I1]; + + const auto wei_grid_desc = make_naive_tensor_descriptor( + make_tuple(K, X * C), make_tuple(KStride, CStride)); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return wei_grid_desc; + } + else + { + const index_t GemmM = K; + const index_t GemmN = C * X; + const auto PadGemmM = + GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = + GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; + + return transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }, + Number{}); + } + + template ::type = false> + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides) + { + return generate_tuple( + [&](auto i) { + const index_t K = ds_g_k_c_xs_lengths[i][I1]; + const index_t C = ds_g_k_c_xs_lengths[i][I2]; + const index_t Y = ds_g_k_c_xs_lengths[i][I3]; + const index_t X = ds_g_k_c_xs_lengths[i][I4]; + + const auto wei_grid_desc = + conv_to_gemm_transformer.template make_wei_grid_desc( + K, Y, X, C, ds_g_k_c_xs_strides[i]); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return wei_grid_desc; + } + else + { + const index_t GemmM = K; + const index_t GemmN = C * X * Y; + const auto PadGemmM = + GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = + GemmN % NPerBlock == 0 ? 
0 : NPerBlock - GemmN % NPerBlock; + + return transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }, + Number{}); + } + + template ::type = false> + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides) + { + return generate_tuple( + [&](auto i) { + const index_t K = ds_g_k_c_xs_lengths[i][I1]; + const index_t C = ds_g_k_c_xs_lengths[i][I2]; + const index_t Z = ds_g_k_c_xs_lengths[i][I3]; + const index_t Y = ds_g_k_c_xs_lengths[i][I4]; + const index_t X = ds_g_k_c_xs_lengths[i][I5]; + + const auto wei_grid_desc = + conv_to_gemm_transformer.template make_wei_grid_desc( + K, Z, Y, X, C, ds_g_k_c_xs_strides[i]); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return wei_grid_desc; + } + else + { + const index_t GemmM = K; + const index_t GemmN = C * X * Y * Z; + const auto PadGemmM = + GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = + GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; + + return transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }, + Number{}); + } + + template + static void + InitElementwiseBatchStrides(const ComputePtrOffsetOfBatch& compute_ptr_offset_of_batch_, + std::array& input_batch_strides, + std::array& output_batch_strides) + { + input_batch_strides[I0] = compute_ptr_offset_of_batch_.BatchStrideC_; + output_batch_strides[I0] = compute_ptr_offset_of_batch_.BatchStrideC_; + + // input_batch_strides = {C, Ds...} + static_for<0, NumDTensor, 1>{}([&](auto i) { + input_batch_strides[i + 1] = compute_ptr_offset_of_batch_.BatchStrideDs_[i]; + }); + } + + using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {})); + using CDGridDesc_M_N = decltype(concat_tuple(Tuple{}, DsGridDesc_M_N{})); + using DsGridPointerTuple = decltype(GetDsGridPointerTuple()); + using CDDataTypes = decltype(concat_tuple(Tuple{}, DsGridPointerTuple{})); + using EGridDesc_M_N = CGridDesc_M_N; + static constexpr index_t ClusterLengthMPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; + + using GridwiseElementwise = + GridwiseElementwise, + CDDataTypes, + Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<0, 1>, + decltype(MakeElementwiseInputSequence()), + Sequence, + I1, + I1>; + + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + + using Block2CTileMap = + decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + + struct Argument : public BaseArgument + { + Argument( + const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& p_ds, + const std::array& 
b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t M01, + const ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_ds_grid_{}, + p_e_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + ce_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + compute_ptr_offset_of_batch_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + cde_element_op_{wei_element_op}, + Conv_G_{b_g_n_c_wis_lengths[0]}, + Conv_N_{b_g_n_c_wis_lengths[1]}, + Conv_K_{e_g_k_c_xs_lengths[1]}, + Conv_C_{b_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + c_space_size_bytes = + ck::accumulate_n( + e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(AccDataType); + + constexpr index_t spatial_offset = 3; + std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, + end(b_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset, + end(e_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset, + end(a_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + + const auto descs = + conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides, + e_g_k_c_xs_strides, + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + static_assert(is_same_v, "Not supported D data layout"); + + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_k_c_xs_strides[i][0]; + }); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + ce_grid_desc_m_n_ = descs[I2]; + + ds_grid_descs_tuple_ = + MakeDsGridDescriptor_M_N(ds_g_k_c_xs_lengths, ds_g_k_c_xs_strides); + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(ce_grid_desc_m_n_, M01, N01, k_batch_); + elementwise_block_2_ctile_map_ = Block2TileMapElementwise{ + ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)}; + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = + Conv_K_ * Conv_C_ * + std::accumulate(begin(filter_spatial_lengths_), + end(filter_spatial_lengths_), + index_t{1}, + 
std::multiplies<>{}); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + ce_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock( + ce_grid_desc_m_n_); + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + return sizeof(AccDataType) * ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_; + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + DsGridPointerTuple p_ds_grid_; + EDataType* p_e_grid_; + + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N ce_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + DsGridDesc_M_N ds_grid_descs_tuple_; + + Block2CTileMap block_2_ctile_map_; + Block2TileMapElementwise elementwise_block_2_ctile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + index_t M01_; + index_t N01_; + + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; + WeiElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + const index_t k_batch_; + long_index_t c_space_size_bytes; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.ce_grid_desc_m_n_{" << arg.ce_grid_desc_m_n_.GetLength(I0) << ", " + << arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.ce_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); + } + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + auto launch_gemm_kernel = [&](auto has_main_k_block_loop) { + AccDataType* p_c_grid = type_convert(arg.p_workspace_); + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_; + + constexpr bool has_main_loop = has_main_k_block_loop.value; + + auto preprocess = [&]() { + hip_check_error(hipMemsetAsync( + p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); + }; + + const auto kernel = kernel_batched_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, + BDataType, + AccDataType, + OutElementwiseOperation, + InElementwiseOperation, + element_wise::PassThrough, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel_with_preprocess( + stream_config, + preprocess, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + p_c_grid, + arg.a_element_op_, + arg.b_element_op_, + element_wise::PassThrough{}, + arg.Conv_G_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + auto launch_elementwise_kernel = [&]() { + const AccDataType* p_c_grid = type_convert(arg.p_workspace_); + const index_t grid_size = + arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * + arg.Conv_G_; + + std::array input_batch_strides; + std::array output_batch_strides; + InitElementwiseBatchStrides( + arg.compute_ptr_offset_of_batch_, input_batch_strides, output_batch_strides); + + const auto kernel = kernel_batched_elementwise, + CDDataTypes, + ck::Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + NumDTensor + I1, + I1>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + concat_tuple(make_tuple(arg.ce_grid_desc_m_n_), arg.ds_grid_descs_tuple_), + make_tuple(arg.ce_grid_desc_m_n_), + concat_tuple(make_tuple(p_c_grid), arg.p_ds_grid_), + arg.p_e_grid_, + arg.elementwise_block_2_ctile_map_, + arg.cde_element_op_, + arg.Conv_G_, + input_batch_strides, + output_batch_strides); + }; + + float avg_time = 0; + if(has_main_k0_block_loop) + { + avg_time = launch_gemm_kernel(integral_constant{}); + } + else + { + avg_time = launch_gemm_kernel(integral_constant{}); + } + + avg_time += launch_elementwise_kernel(); + return avg_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(NDimSpatial == 1) + { + if constexpr(!is_GNWC_GKXC_GNWK()) + { + return false; + } + } + else if constexpr(NDimSpatial == 2) + { + if constexpr(!(is_NHWGC_GKYXC_NHWGK() || + is_GNHWC_GKYXC_GNHWK())) + { + return false; + } + } + else if constexpr(NDimSpatial == 3) + { + if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK() || + is_GNDHWC_GKZYXC_GNDHWK())) + { + return false; + } + } + else + { + return false; + } + + if 
constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0 && + arg.Conv_C_ % WorkspaceInOutScalarPerVector == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.ce_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& p_ds, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + p_ds, + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + ds_g_k_c_xs_lengths, + ds_g_k_c_xs_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& p_ds, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array, NumDTensor>& ds_g_k_c_xs_lengths, + const std::array, NumDTensor>& ds_g_k_c_xs_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + p_ds, + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + 
e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + ds_g_k_c_xs_lengths, + ds_g_k_c_xs_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << K1 << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp new file mode 100755 index 000000000000..c7c463f43d16 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -0,0 +1,1905 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
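// Illustrative standalone sketch of the two-stage split-K scheme implemented in
// this header, with hypothetical sizes and plain loops standing in for the XDL
// pipelines: stage 1 accumulates each split-K chunk into a float workspace (the
// device op does this with InMemoryDataOperationEnum::AtomicAdd when k_batch > 1),
// and stage 2 runs an elementwise pass that produces the final weight tensor.
#include <iostream>
#include <numeric>
#include <vector>

int main()
{
    const int M = 4, N = 3, K = 8, k_batch = 2; // GemmM, GemmN, GemmK and the split-K factor
    std::vector<float> a(M * K), b(K * N);
    std::iota(a.begin(), a.end(), 0.0f);
    std::iota(b.begin(), b.end(), 1.0f);

    std::vector<float> workspace(M * N, 0.0f); // AccDataType workspace, cleared up front
    const int k_per_batch = K / k_batch;

    // Stage 1: every split-K batch adds its partial product into the shared workspace.
    for(int kb = 0; kb < k_batch; ++kb)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                for(int k = kb * k_per_batch; k < (kb + 1) * k_per_batch; ++k)
                    workspace[m * N + n] += a[m * K + k] * b[k * N + n];

    // Stage 2: an elementwise pass writes the final weights (pass-through here; the
    // device op additionally casts AccDataType to WeiDataType and, for NGCHW-style
    // layouts, transposes).
    std::vector<float> wei(workspace);

    std::cout << "wei[0] = " << wei[0] << "\n"; // 448 for these inputs
}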
+ +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp" +#include +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + [[maybe_unused]] const index_t num_k_per_block) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + [[maybe_unused]] const index_t num_k_per_block) +{ 
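    // Grid mapping (shared with the single-LDS kernel above): blockIdx.z selects the
    // convolution group(s), scaled by NumGroupsToMerge, and blockIdx.y selects the
    // split-K chunk, scaled by num_k_per_block. As a purely illustrative example,
    // with NumGroupsToMerge == 2 and blockIdx.z == 3 this workgroup handles group
    // g_idx == 6, and the A/B/E pointers are advanced by that group's batch strides
    // taken from compute_ptr_offset_of_batch.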
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle + : public DeviceGroupedConvBwdWeight +{ + static_assert(is_same_v); + static_assert(is_same_v); + static_assert(is_same_v); + + using DeviceOp = DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle; + + using ADataType = OutDataType; + using BDataType = InDataType; + using EDataType = WeiDataType; + + // If NGCHW then ADataType must be equal to BDataType + static_assert(!(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) || + is_same_v); + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CDEElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static constexpr auto conv_to_gemm_transformer_v2 = + TransformConvBwdWeightToGemmV2{}; + + static constexpr auto conv_to_gemm_transformer_v1 = + TransformConvBwdWeightToGemm{}; + + static constexpr index_t ClusterLengthMPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + static constexpr auto conv_ngchw_to_nhwgc_transformer = + TransformConvNGCHWToNHWGC{}; + + static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default; + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return conv_to_gemm_transformer_v2 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t 
batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return conv_to_gemm_transformer_v2 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetElementwiseCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return conv_to_gemm_transformer_v1 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch)[I2]; + } + + template ::type = false> + static auto GetElementwiseCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return conv_to_gemm_transformer_v1 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch)[I2]; + } + + using NGCHWTransposeDescType = + remove_cvref_t({}, {}))>; + using NHWGCTransposeDescType = + remove_cvref_t({}, {}))>; + using GKCYXTransposeDescType = + remove_cvref_t({}, {}))>; + using GKYXCTransposeDescType = + remove_cvref_t({}, {}))>; + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + using CElementwiseGridDesc_M_N = + remove_cvref_t())>; + + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_conv_v3< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + ADataType, + BDataType, + AccDataType, + AccDataType, + AccDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + K1, + K1, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CBlockTransferScalarPerVector_NWaveNPerXdl, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; + + using GridwiseElementwiseCast = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<0, 1>, + Sequence, + Sequence, + I1, + I1>; + // NPerBlock is used for the first dim which is store dimension + // (with CBlockTransferScalarPerVector_NWaveNPerXdl scalar per 
vector). + // CBlockTransferScalarPerVector_NWaveNPerXdl is aligned to NPerBlock so + // it is more flexible to use this dim for store dimension with such scalar + // per vector. + using GridwiseElementwiseWeightTransposeCast = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<0, 1>, + Sequence, + Sequence<1>, + I1, + I0>; + + using GridwiseElementwiseTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I1, + I0>; + + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + CGridDesc_M_N{}, 1, 1)); + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t M01, + const ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_e_grid_{p_wei_grid}, + a_grid_desc_k0_m_k1_{}, + b_grid_desc_k0_n_k1_{}, + ce_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + compute_ptr_offset_of_batch_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + cde_element_op_{wei_element_op}, + Conv_G_{b_g_n_c_wis_lengths[0]}, + Conv_N_{b_g_n_c_wis_lengths[1]}, + Conv_K_{e_g_k_c_xs_lengths[1]}, + Conv_C_{b_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + c_space_size_bytes = + ck::accumulate_n( + e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(AccDataType); + + constexpr index_t spatial_offset = 3; + std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, + end(b_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset, + end(e_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset, + end(a_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + + std::array a_g_n_k_wos_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, + a_g_n_k_wos_strides); + std::array b_g_n_c_wis_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths, + b_g_n_c_wis_strides); + std::array e_g_k_c_xs_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths, + e_g_k_c_xs_strides); + + const auto descs = + conv_to_gemm_transformer_v2 + .template 
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides_transposed, + e_g_k_c_xs_strides_transposed, + a_g_n_k_wos_strides_transposed, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_k0_m_k1_ = descs[I0]; + b_grid_desc_k0_n_k1_ = descs[I1]; + ce_grid_desc_m_n_ = descs[I2]; + + ce_elementwise_grid_desc_m_n_ = + conv_to_gemm_transformer_v1 + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides, + e_g_k_c_xs_strides, + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_)[I2]; + + const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1); + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0]; + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ce_grid_desc_m_n_, + GridwiseGemm::CalculateMBlock(GemmM), + GridwiseGemm::CalculateNBlock(GemmN)); + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + a_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides); + a_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides); + + b_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + b_g_n_c_wis_lengths, b_g_n_c_wis_strides); + b_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + b_g_n_c_wis_lengths, b_g_n_c_wis_strides); + + e_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc( + e_g_k_c_xs_lengths, e_g_k_c_xs_strides); + e_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc( + e_g_k_c_xs_lengths, e_g_k_c_xs_strides); + + elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{ + a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; + + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{ + b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; + } + + elementwise_block_2_ctile_map_ = + is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW() + ? 
Block2TileMapElementwise{e_in_transpose_desc_.GetLength(I0), + e_in_transpose_desc_.GetLength(I1)} + : Block2TileMapElementwise{ce_grid_desc_m_n_.GetLength(I0), + ce_grid_desc_m_n_.GetLength(I1)}; + } + + std::size_t GetWorkspaceATensorSizeBytes() const + { + // Align to 128B + return math::integer_divide_ceil( + sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(), 128) * + 128; + } + + std::size_t GetWorkspaceBTensorSizeBytes() const + { + return sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize(); + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + // Align to 128B + return math::integer_divide_ceil(sizeof(AccDataType) * + ce_grid_desc_m_n_.GetElementSpaceSize() * Conv_G_, + 128) * + 128; + } + + std::size_t GetWorkspaceSizeBytes() const + { + // 1. We need to transpose A and B for NGCHW and NGKHW layouts + // 2. If C format is GKCYX then tranpose during second stage. + // If C format is GKYXC then just perform second stage. + // Due to the fact that E workspace is always needed, we + // allocate them as the first part of the workspace. + // [EWorkspace, AWorkspace, BWorkspace] + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() + + GetWorkspaceETensorSizeBytes(); + } + else + { + return GetWorkspaceETensorSizeBytes(); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + CGridDesc_M_N ce_grid_desc_m_n_; + CElementwiseGridDesc_M_N ce_elementwise_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + + Block2TileMapElementwise elementwise_block_2_ctile_map_; + Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_, + elementwise_block_2_ctile_map_transpose_b_; + + NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_; + NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_; + GKYXCTransposeDescType e_in_transpose_desc_; + GKCYXTransposeDescType e_out_transpose_desc_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + index_t M01_; + index_t N01_; + + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; + WeiElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + const index_t k_batch_; + long_index_t c_space_size_bytes; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.ce_grid_desc_m_n_{" << arg.ce_grid_desc_m_n_.GetLength(I0) << ", " + << arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + 
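        // Split-K bookkeeping used by RunGemmV3 below, with purely illustrative
        // numbers: for KPerBlock == 32, GemmK == 200 and KBatch == 2, k_grain is
        // KBatch * KPerBlock == 64 and K_split == ceil(200 / 64) * 32 == 128, i.e.
        // each split-K batch loops over a K range padded up to a multiple of
        // KPerBlock; K_split then drives has_main_k_block_loop and the tail-number
        // dispatch. When KBatch > 1 the AccDataType workspace is cleared first
        // because the batches accumulate into it with AtomicAdd; with KBatch == 1
        // the kernels write with InMemoryDataOperationEnum::Set and no clearing is
        // needed.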
float RunGemmV3(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + AccDataType* p_c_grid = type_convert(arg.p_workspace_); + + const ADataType* p_a_grid = arg.p_a_grid_; + const BDataType* p_b_grid = arg.p_b_grid_; + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceETensorSizeBytes() / sizeof(ADataType); + p_b_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) / + sizeof(BDataType); + } + + // nullptr for output, will be set after workspace set + typename GridwiseGemm::Argument gemm_arg{ + p_a_grid, p_b_grid, p_c_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( + gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumGroupsToMerge); + + float ave_time = 0; + + index_t k_grain = gemm_arg.KBatch * KPerBlock; + index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * (KPerBlock); + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto num_k_per_block = + arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; + + const auto clear_workspace = [&]() { + if(arg.k_batch_ > 1) + { + hip_check_error(hipMemsetAsync( + gemm_arg.p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); + } + }; + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; + ck::utility::RotatingMemWrapper rotating_mem( + gemm_arg_, + stream_config.rotating_count, + gemm_arg_.M * gemm_arg_.K * sizeof(ADataType), + gemm_arg_.K * gemm_arg_.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + clear_workspace(); + }; + + ave_time += ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg_, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + else + { + ave_time += launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_k0_m_k1_, + arg.b_grid_desc_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 
1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + 
TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + 
minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + 
InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + NumGroupsToMerge, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + + return ave_time; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0.f; + auto launch_elementwise_kernel = [&]() { + const AccDataType* p_c_grid = type_convert(arg.p_workspace_); + + std::array in_out_batch_strides = { + static_cast(arg.compute_ptr_offset_of_batch_.BatchStrideC_)}; + + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + const index_t grid_size = 
arg.elementwise_block_2_ctile_map_.CalculateGridSize( + arg.e_in_transpose_desc_); + + const auto kernel = kernel_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + make_tuple(arg.e_in_transpose_desc_), + make_tuple(arg.e_out_transpose_desc_), + make_tuple(p_c_grid), + make_tuple(arg.p_e_grid_), + arg.elementwise_block_2_ctile_map_, + arg.cde_element_op_); + } + else + { + const index_t grid_size = arg.elementwise_block_2_ctile_map_.CalculateGridSize( + arg.ce_elementwise_grid_desc_m_n_) * + arg.Conv_G_; + + const auto kernel = + kernel_batched_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + CDEElementwiseOperation, + I1, + I1>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + make_tuple(arg.ce_elementwise_grid_desc_m_n_), + make_tuple(arg.ce_elementwise_grid_desc_m_n_), + make_tuple(p_c_grid), + make_tuple(arg.p_e_grid_), + arg.elementwise_block_2_ctile_map_, + arg.cde_element_op_, + arg.Conv_G_, + in_out_batch_strides, + in_out_batch_strides); + } + }; + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + const index_t grid_size_a = + arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( + arg.a_in_transpose_desc_); + const index_t grid_size_b = + arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( + arg.b_in_transpose_desc_); + + ADataType* p_a_out_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceETensorSizeBytes() / sizeof(ADataType); + BDataType* p_b_out_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceETensorSizeBytes() + arg.GetWorkspaceATensorSizeBytes()) / + sizeof(BDataType); + + // Different data type for A and B is not supported + auto kernel_transpose = kernel_elementwise_dual, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + Block2TileMapElementwise, + element_wise::PassThrough>; + + avg_time += launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(grid_size_a + grid_size_b), + dim3(BlockSize), + 0, + make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.b_in_transpose_desc_), + make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.b_out_transpose_desc_), + make_tuple(arg.p_a_grid_), + make_tuple(arg.p_b_grid_), + make_tuple(p_a_out_grid), + make_tuple(p_b_out_grid), + arg.elementwise_block_2_ctile_map_transpose_a_, + arg.elementwise_block_2_ctile_map_transpose_b_, + element_wise::PassThrough{}, + grid_size_a); + } + + avg_time += RunGemmV3(arg, stream_config); + avg_time += launch_elementwise_kernel(); + return avg_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2); + + typename GridwiseGemm::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / K1); + if 
constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // Check this here, it allows to use other instances from factory even + // if workspace is not allocated + if(!arg.p_workspace_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Warning: Workspace for " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument is not " + "allocated, use SetWorkSpacePointer." + << std::endl; + } + return false; + } + if(!ck::is_xdl_supported()) + { + return false; + } + if constexpr(NDimSpatial == 2) + { + if constexpr(!(is_NHWGC_GKYXC_NHWGK() || + is_NGCHW_NGKHW())) + { + return false; + } + } + else if constexpr(NDimSpatial == 3) + { + if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK() || + is_NGCDHW_NGKDHW())) + { + return false; + } + } + else + { + return false; + } + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + if constexpr(NumGroupsToMerge > 1) + { + // support only if whole M and N can be proccessed on one block + if(!(GemmM <= MPerBlock && GemmN <= NPerBlock)) + { + return false; + } + if(!(arg.Conv_C_ == 1 && arg.Conv_K_ == 1)) + { + return false; + } + if(arg.Conv_G_ % NumGroupsToMerge != 0) + { + return false; + } + } + + const bool is_w_pad_zero = arg.input_left_pads_[NDimSpatial - 1] == 0 && + arg.input_right_pads_[NDimSpatial - 1] == 0; + const auto X = arg.filter_spatial_lengths_[NDimSpatial - 1]; + const bool XC_access_allowed = arg.Conv_G_ == 1 && + (arg.Conv_C_ * X) % BBlockTransferSrcScalarPerVector == 0 && + is_w_pad_zero; + + if(!((arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0 || XC_access_allowed) && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0)) + { + if(!(arg.Conv_K_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideA_ == 1 && + NumGroupsToMerge > 1)) + { + return false; + } + if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1 && + NumGroupsToMerge > 1)) + { + return false; + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVector != 0) + { + return false; + } + + if((arg.Conv_G_ * arg.Conv_K_) % TransposeTransferDstScalarPerVector != 0) + { + return false; + } + + const index_t input_spatial_acum = ck::accumulate_n( + arg.input_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const index_t output_spatial_acum = ck::accumulate_n( + arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + + if(input_spatial_acum % TransposeTransferSrcScalarPerVector != 0) + { + return false; + } + + if(output_spatial_acum % TransposeTransferSrcScalarPerVector != 0) + { + return false; + } + + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + arg.b_out_transpose_desc_.GetElementSpaceSize() * 
sizeof(BDataType) <= TwoGB)) + { + return false; + } + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << K1 << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector 
<< ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl << ", " + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", " + << NumGroupsToMerge; + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { + str << ", TransposeTransferSrcScalarPerVector: " + << TransposeTransferSrcScalarPerVector <<", " + << "TransposeTransferDstScalarPerVector: " << TransposeTransferDstScalarPerVector; + } + + + str << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp new file mode 100755 index 000000000000..e9e02eae819d --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp @@ -0,0 +1,871 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. 
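+//
+// This header adds a WMMA-based device operation for grouped convolution
+// backward-weight (DeviceGroupedConvBwdWeight_Wmma_CShuffle below), targeting
+// WMMA-capable GPUs (gfx11/gfx12, as checked in IsSupportedArgument) and
+// currently restricted to split_k == 1. Typical host-side use follows the
+// factory/invoker pattern exposed by this class; the sketch below is
+// illustrative only, with template parameters and argument values elided
+// because they are application-specific:
+//
+//   using DeviceOp = DeviceGroupedConvBwdWeight_Wmma_CShuffle</* ... */>;
+//   auto argument = DeviceOp::MakeArgument(p_in, p_wei, p_out,
+//                                          /* lengths, strides, conv params, pads */
+//                                          in_element_op, wei_element_op, out_element_op,
+//                                          /* split_k = */ 1); // split_k > 1 is rejected
+//   if(DeviceOp::IsSupportedArgument(argument))
+//   {
+//       auto invoker = DeviceOp::MakeInvoker();
+//       invoker.Run(argument, StreamConfig{});
+//   }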
+ +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template ::type = false> +struct DeviceGroupedConvBwdWeight_Wmma_CShuffle + : public DeviceGroupedConvBwdWeight +{ + using DeviceOp = DeviceGroupedConvBwdWeight_Wmma_CShuffle; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto GemmK1Number = Number{}; + static constexpr index_t KPerBlock = K0PerBlock * GemmK1Number; + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Do, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t WoStride = output_strides[5]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K), + make_tuple(WoStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Di, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t NStride = input_strides[1]; + const index_t DiStride = input_strides[3]; + const index_t HiStride = input_strides[4]; + const index_t WiStride = input_strides[5]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C), + make_tuple(WiStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Z, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C), + make_tuple(KStride, CStride)); + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const 
std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + using namespace ck; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t GemmKTotal = N * Do * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * Z * X * Y; + + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; + + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock; + const index_t GemmKPad = GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + 
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // Pad + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + 
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } + + template ::type = false> + static auto GetABCGridDesc() + { + const index_t dim = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using CShuffleDataType = AccDataType; + + using GridwiseGemm = GridwiseGemmMultipleD_Wmma< + // DataType Family + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + Tuple<>, + CDataType, + // InMemory Data Descriptor + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + Tuple<>, + CGridDesc_M_N, + // ElementwiseOp Family + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // Tiling Family + MPerBlock, + NPerBlock, + KPerBlock, + MPerWMMA, + NPerWMMA, + K1, + MRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, + true, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, + true, + BBlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + NumGemmKPrefetchStage, + LoopSched, + PipelineVer>; + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(Tuple<>{})); + + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + CGridDesc_M_N{})); + + using Block2CTileMap = decltype(GridwiseGemm::MakeDefaultBlock2CTileMap( + CGridDesc_M_N{}, I1 /* M01 */, I1 /* N01 */)); + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + 
WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + compute_ptr_offset_of_batch_{}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + c_element_op_{wei_element_op}, + Conv_G_{a_g_n_c_wis_lengths[0]}, + Conv_N_{a_g_n_c_wis_lengths[1]}, + Conv_K_{b_g_k_c_xs_lengths[1]}, + Conv_C_{a_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + constexpr index_t spatial_offset = 3; + std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset, + end(a_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset, + end(b_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset, + end(e_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + + const auto descs = + DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + a_g_n_c_wis_strides, + b_g_k_c_xs_strides, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap( + c_grid_desc_m_n_, I1 /* M01 */, I1 /* N01 */); + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = + Conv_K_ * Conv_C_ * + std::accumulate(begin(filter_spatial_lengths_), + end(filter_spatial_lengths_), + index_t{1}, + std::multiplies<>{}); + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + + Block2CTileMap block_2_ctile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; + + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; + WeiElementwiseOperation c_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + const index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + 
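+
+        // Print dumps the lengths of the K0/M/K1 (A), K0/N/K1 (B) and M/N (C) grid
+        // descriptors for debugging. Run validates the problem with
+        // GridwiseGemm::CheckValidity, derives the grid size from the block-to-C-tile
+        // map scaled by the number of groups (Conv_G_), and launches
+        // kernel_grouped_conv_multiple_d_wmma_cshuffle, selecting the main-K0-loop
+        // variant from the A descriptor's K0 length.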
+ void Print(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + Print(arg); + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle has invalid " + "setting"); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_; + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle< + GridwiseGemm, + ADataType, + BDataType, + typename GridwiseGemm::DsGridPointer, + CDataType, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + remove_reference_t, + ComputePtrOffsetOfStridedBatch<>, + has_main_loop>; + + using EmptyTuple = Tuple<>; + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + EmptyTuple{}, // Ds + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.Conv_G_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock{}, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + if(has_main_k0_block_loop) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // check device + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + // TODO: Add support for split_k > 1 + if(arg.k_batch_ != 1) + { + return false; + } + + if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK() || + is_GNDHWC_GKZYXC_GNDHWK())) + { + return false; + } + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's a 1x1 convolution with stride=1 and no padding + for(int i = 0; i < 
NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& a_g_n_c_wis_lengths, // input + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, + const std::array& e_g_n_k_wos_lengths, // output + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + a_g_n_c_wis_lengths, // input + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, // weight + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, // output + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdWeight_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << 
getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << K1 << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << CShuffleBlockTransferScalarPerVector_NPerBlock + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp new file mode 100755 index 000000000000..6c53161ded47 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -0,0 +1,1146 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_gemm_xdlops_bwd_weight( + const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const index_t batch_count, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); + const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); + const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); + + __shared__ FloatA 
p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)]; + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = batch_count; + ignore = block_2_ctile_map; + ignore = compute_ptr_offset_of_batch; + + compute_ptr_offset_of_batch.GetAPtrOffset(0); + compute_ptr_offset_of_batch.GetBPtrOffset(0); + compute_ptr_offset_of_batch.GetCPtrOffset(0); +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceGroupedConvBwdWeight_Xdl_CShuffle + : public DeviceGroupedConvBwdWeight +{ + using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffle; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + // If NGCHW then ADataType must be equal to BDataType + static_assert(!(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) || + is_same_v); + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvBwdWeightToGemm{}; + + // Bytes per 32 lds bank: 32 * 4 bytes + static constexpr auto BankLength = 128; + static constexpr auto ElePerBank = BankLength / sizeof(ADataType); + + // M1 & M0 + static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1; + static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock; + static constexpr auto ABlockLdsM1Padding = 4; + + // N1 & N0 + static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1; + static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock; + static constexpr auto BBlockLdsN1Padding = 4; + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1}; + const std::array strides{1, 1, 1, 1}; + const std::array params{1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = 
false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 1, 1, 1}; + const std::array params{1, 1, 1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + static constexpr index_t ClusterLengthMPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + static constexpr auto conv_ngchw_to_nhwgc_transformer = + TransformConvNGCHWToNHWGC{}; + + using Block2TileMapTranspose = BlockToCTileMap_M00_N0_M01Adapt; + + static constexpr index_t TransposeTransferSrcScalarPerVectorAligned = + std::min(NPerBlock / ClusterLengthNPerBlock, MaxTransposeTransferSrcScalarPerVector); + static constexpr index_t TransposeTransferDstScalarPerVectorAligned = + std::min(MPerBlock / ClusterLengthMPerBlock, MaxTransposeTransferDstScalarPerVector); + + using NGCHWTransposeDescType = + remove_cvref_t({}, {}))>; + using NHWGCTransposeDescType = + remove_cvref_t({}, {}))>; + using GKCYXTransposeDescType = + remove_cvref_t({}, {}))>; + using GKYXCTransposeDescType = + remove_cvref_t({}, {}))>; + + using GridwiseInOutTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapTranspose, + element_wise::PassThrough, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I1, + I0>; + + // NPerBlock is used for the first dim which is store dimension + // (with CBlockTransferScalarPerVector_NWaveNPerXdl scalar per vector). 
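+    // The same elementwise grid machinery is reused below to transpose the weight
+    // tensor between its GKYXC and GKCYX forms on the NGCHW/NGCDHW path; unlike the
+    // input/output transpose above, it uses a destination scalar-per-vector of 1
+    // (Sequence<1>) on the store side.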
+ using GridwiseElementwiseWeightTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapTranspose, + element_wise::PassThrough, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence<1>, + I1, + I0>; + + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight< + BlockSize, + ADataType, + BDataType, + AccDataType, + CDataType, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_K0_M_K1, + BGridDesc_K0_N_K1, + CGridDesc_M_N, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXdl, + NPerXdl, + K1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsAddExtraM, + ABlockLdsM1PerBlock, + ABlockLdsM0PerBlock, + ABlockLdsM1Padding, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsAddExtraN, + BBlockLdsN1PerBlock, + BBlockLdsN0PerBlock, + BBlockLdsN1Padding, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferScalarPerVector_NWaveNPerXdl, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + true, + true, + 1, + PipelineVersion::v1, + ComputeTypeA, + ComputeTypeB>; + + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{})); + + using Block2CTileMap = + decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t M01, + const ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{}, + compute_ptr_offset_of_batch_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + c_element_op_{wei_element_op}, + Conv_G_{b_g_n_c_wis_lengths[0]}, + Conv_N_{b_g_n_c_wis_lengths[1]}, + Conv_K_{e_g_k_c_xs_lengths[1]}, + Conv_C_{b_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + 
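+            // Constructor body: compute the size in bytes of the weight (C) tensor,
+            // copy the spatial lengths out of the packed G/N/C/K length arrays,
+            // rewrite NGCHW-style strides into their NHWGC-style equivalents,
+            // build the A/B/C GEMM descriptors through the conv-to-GEMM transformer
+            // with the requested split-K factor, set the per-group batch strides,
+            // and, for NGCHW/NGCDHW layouts, prepare the transpose descriptors and
+            // their block-to-tile maps.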
c_space_size_bytes = + ck::accumulate_n( + e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(WeiDataType); + + constexpr index_t spatial_offset = 3; + std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, + end(b_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset, + end(e_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset, + end(a_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + + std::array a_g_n_k_wos_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(a_g_n_k_wos_lengths, + a_g_n_k_wos_strides); + std::array b_g_n_c_wis_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides(b_g_n_c_wis_lengths, + b_g_n_c_wis_strides); + std::array e_g_k_c_xs_strides_transposed = + conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths, + e_g_k_c_xs_strides); + const auto descs = + conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides_transposed, + e_g_k_c_xs_strides_transposed, + a_g_n_k_wos_strides_transposed, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + block_2_ctile_map_ = + GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_); + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides_transposed[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides_transposed[0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = e_g_k_c_xs_strides_transposed[0]; + + if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_, + b_grid_desc_kbatch_k0_n_k1_, + c_grid_desc_m_n_, + block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_); + } + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + a_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides); + a_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides); + + b_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + b_g_n_c_wis_lengths, b_g_n_c_wis_strides); + b_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + b_g_n_c_wis_lengths, b_g_n_c_wis_strides); + + e_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc( + e_g_k_c_xs_lengths, e_g_k_c_xs_strides); + e_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc( + e_g_k_c_xs_lengths, e_g_k_c_xs_strides); + + elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapTranspose{ + a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; + + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapTranspose{ + b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; + + elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapTranspose{ + e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)}; + } + } + + std::size_t 
GetWorkspaceATensorSizeBytes() const + { + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + // Align to 128B + return math::integer_divide_ceil( + sizeof(ADataType) * a_in_transpose_desc_.GetElementSpaceSize(), 128) * + 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceBTensorSizeBytes() const + { + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + // Align to 128B + return math::integer_divide_ceil( + sizeof(BDataType) * b_in_transpose_desc_.GetElementSpaceSize(), 128) * + 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + return sizeof(CDataType) * e_in_transpose_desc_.GetElementSpaceSize(); + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() + + GetWorkspaceETensorSizeBytes(); + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + + Block2CTileMap block_2_ctile_map_; + + Block2TileMapTranspose elementwise_block_2_ctile_map_transpose_a_, + elementwise_block_2_ctile_map_transpose_b_, elementwise_block_2_ctile_map_transpose_e_; + + NGCHWTransposeDescType a_in_transpose_desc_, b_in_transpose_desc_; + NHWGCTransposeDescType a_out_transpose_desc_, b_out_transpose_desc_; + + GKYXCTransposeDescType e_in_transpose_desc_; + GKCYXTransposeDescType e_out_transpose_desc_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; + + index_t M01_; + index_t N01_; + + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; + WeiElementwiseOperation c_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + const index_t k_batch_; + long_index_t c_space_size_bytes; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0.f; + + const ADataType* p_a_grid = arg.p_a_grid_; + const BDataType* p_b_grid = arg.p_b_grid_; + CDataType* p_e_grid = arg.p_c_grid_; + + if 
constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + p_e_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(CDataType); + } + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + const index_t grid_size_a = + arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( + arg.a_in_transpose_desc_); + const index_t grid_size_b = + arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( + arg.b_in_transpose_desc_); + + p_a_grid = type_convert(arg.p_workspace_); + p_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + ADataType* p_out_a_grid = type_convert(arg.p_workspace_); + BDataType* p_out_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + + // Different data type for A and B is not supported + auto kernel_transpose = kernel_elementwise_dual, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapTranspose, + Block2TileMapTranspose, + element_wise::PassThrough>; + + avg_time += launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(grid_size_a + grid_size_b), + dim3(BlockSize), + 0, + make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.b_in_transpose_desc_), + make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.b_out_transpose_desc_), + make_tuple(arg.p_a_grid_), + make_tuple(arg.p_b_grid_), + make_tuple(p_out_a_grid), + make_tuple(p_out_b_grid), + arg.elementwise_block_2_ctile_map_transpose_a_, + arg.elementwise_block_2_ctile_map_transpose_b_, + element_wise::PassThrough{}, + grid_size_a); + } + + const index_t grid_size = + arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_; + + const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_batched_gemm_xdlops_bwd_weight< + GridwiseGemm, + ADataType, + BDataType, + CDataType, + OutElementwiseOperation, + InElementwiseOperation, + WeiElementwiseOperation, + remove_reference_t, + remove_reference_t, + remove_reference_t, + remove_reference_t, + ComputePtrOffsetOfStridedBatch<>, + has_main_loop>; + + const auto clear_workspace = [&]() { + hip_check_error(hipMemsetAsync( + p_e_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); + }; + + avg_time += launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + p_e_grid, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.Conv_G_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + if(has_main_k0_block_loop) + { + launch_kernel(integral_constant{}); + } + else + { + launch_kernel(integral_constant{}); + } + + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + const index_t grid_size_e = + arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( + arg.e_in_transpose_desc_); + + const CDataType* p_e_in_grid = static_cast(p_e_grid); + + // Different data type for A and B is not supported + auto kernel_transpose = kernel_elementwise, + ck::Tuple, + ck::Tuple, + 
ck::Tuple, + Block2TileMapTranspose, + element_wise::PassThrough>; + + avg_time += launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(grid_size_e), + dim3(BlockSize), + 0, + make_tuple(arg.e_in_transpose_desc_), + make_tuple(arg.e_out_transpose_desc_), + make_tuple(p_e_in_grid), + make_tuple(arg.p_c_grid_), + arg.elementwise_block_2_ctile_map_transpose_e_, + element_wise::PassThrough{}); + } + + return avg_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + if(!is_bf16_atomic_supported() && std::is_same_v) + { + return false; + } + if constexpr(NDimSpatial == 1) + { + if constexpr(!is_GNWC_GKXC_GNWK()) + { + return false; + } + } + else if constexpr(NDimSpatial == 2) + { + if constexpr(!(is_NHWGC_GKYXC_NHWGK() || + is_GNHWC_GKYXC_GNHWK() || + is_NGCHW_NGKHW())) + { + return false; + } + } + else if constexpr(NDimSpatial == 3) + { + if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK() || + is_GNDHWC_GKZYXC_GNDHWK() || + is_NGCDHW_NGKDHW())) + { + return false; + } + } + else + { + return false; + } + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + + // vector load A/B matrix from global memory + if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + if((arg.Conv_G_ * arg.Conv_C_) % TransposeTransferDstScalarPerVectorAligned != 0) + { + return false; + } + + if((arg.Conv_G_ * arg.Conv_K_) % TransposeTransferDstScalarPerVectorAligned != 0) + { + return false; + } + + const index_t input_spatial_acum = ck::accumulate_n( + arg.input_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + const index_t output_spatial_acum = ck::accumulate_n( + arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); + + if(input_spatial_acum % TransposeTransferSrcScalarPerVectorAligned != 0) + { + return false; + } + + if(output_spatial_acum % TransposeTransferSrcScalarPerVectorAligned != 0) + { + return false; + } + + if(!arg.p_workspace_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Warning: Workspace for " + "DeviceGroupedConvBwdWeight_Xdl_CShuffle::Argument is not " + "allocated, use SetWorkSpacePointer." 
+ << std::endl; + } + return false; + } + + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + arg.b_out_transpose_desc_.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB)) + { + return false; + } + } + + // Gridwise GEMM size + return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) override + { + return std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdWeight_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << K1 << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " 
+ << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl; + + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) { + str << ", TransposeTransferSrcScalarPerVectorAligned: " + << TransposeTransferSrcScalarPerVectorAligned <<", " + << "TransposeTransferDstScalarPerVectorAligned: " << TransposeTransferDstScalarPerVectorAligned; + } + + + str << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeight_Xdl_CShuffle::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdWeight_Xdl_CShuffle::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp new file mode 100755 index 000000000000..f13a256d6bcd --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -0,0 +1,1382 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
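// Editorial note on the workspace protocol of DeviceGroupedConvBwdWeight_Xdl_CShuffle
// defined above (illustrative summary based only on the interface shown): when the
// NGCHW/NGKHW (or NGCDHW/NGKDHW) layouts are used, callers query GetWorkSpaceSize(),
// allocate that many bytes, and hand the buffer over via SetWorkSpacePointer() before
// running the kernel; IsSupportedArgument() rejects arguments whose workspace pointer
// has not been set. The A and B transpose workspace sizes are rounded up to a multiple
// of 128 bytes, i.e. effectively
//
//   constexpr std::size_t align_up_128(std::size_t bytes)
//   {
//       return ((bytes + 127) / 128) * 128; // round up to a 128-byte multiple
//   }
//
// applied to sizeof(DataType) * desc.GetElementSpaceSize() of the corresponding
// transpose descriptor.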
+ +#pragma once + +#include +#include +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp" +#include +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); +#else + ignore = karg; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = compute_ptr_offset_of_batch; + ignore = num_k_per_block; +#endif // end of if (defined(__gfx9__) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // offset base pointer for each work-group + const index_t g_idx = 
__builtin_amdgcn_readfirstlane(blockIdx.z); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_idx); +#else + ignore = karg; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = compute_ptr_offset_of_batch; + ignore = num_k_per_block; +#endif // end of if (defined(__gfx9__) +} + +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +template +struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 + : public DeviceGroupedConvBwdWeight +{ + static_assert(is_same_v); + static_assert(is_same_v); + static_assert(is_same_v); + + using DeviceOp = DeviceGroupedConvBwdWeight_Xdl_CShuffleV3; + + using ADataType = OutDataType; + using BDataType = InDataType; + using CDataType = WeiDataType; + + using AElementwiseOperation = OutElementwiseOperation; + using BElementwiseOperation = InElementwiseOperation; + using CElementwiseOperation = WeiElementwiseOperation; + + // TODO make A/B datatype different + using ABDataType = InDataType; + + static inline auto I0 = Number<0>{}; + static inline auto I1 = Number<1>{}; + static inline auto I2 = Number<2>{}; + static inline auto I3 = Number<3>{}; + static inline auto I4 = Number<4>{}; + static inline auto I5 = Number<5>{}; + + static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default; + static constexpr auto K1Number = Number{}; + + static constexpr auto conv_to_gemm_transformer = + TransformConvBwdWeightToGemmV2{}; + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1}; + const std::array strides{1, 1, 1, 1}; + const std::array params{1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1}; + const std::array strides{1, 1, 1, 1, 1}; + const std::array params{1, 1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + template ::type = false> + static auto GetABCGridDesc() + { + const ck::index_t dim = 1; + const ck::index_t batch = 1; + const std::array lengths{1, 1, 1}; + const std::array strides{1, 1, 1, 
1, 1, 1}; + const std::array params{1, 1, 1}; + return conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>( + dim, + dim, + dim, + lengths, + lengths, + lengths, + strides, + strides, + strides, + params, + params, + params, + params, + batch); + } + + using ABCGridDescs = decltype(GetABCGridDesc()); + + using AGridDesc_K0_M_K1 = remove_cvref_t; + using BGridDesc_K0_N_K1 = remove_cvref_t; + using CGridDesc_M_N = remove_cvref_t; + + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_conv_v3< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor, + ADataType, + BDataType, + AccDataType, + CDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + K0PerBlock, + K1, + K1, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, + ABlockLdsAddExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, + BBlockLdsAddExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CBlockTransferScalarPerVector_NWaveNPerXdl, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + // Argument + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + CGridDesc_M_N{}, 1, 1)); + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const ck::index_t M01, + const ck::index_t N01, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + ck::index_t split_k) + : p_a_grid_{p_out_grid}, + p_b_grid_{p_in_grid}, + p_c_grid_{p_wei_grid}, + a_grid_desc_kbatch_k0_m_k1_{}, + b_grid_desc_kbatch_k0_n_k1_{}, + c_grid_desc_m_n_{}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + compute_ptr_offset_of_batch_{}, + M01_{M01}, + N01_{N01}, + a_element_op_{out_element_op}, + b_element_op_{in_element_op}, + c_element_op_{wei_element_op}, + Conv_G_{b_g_n_c_wis_lengths[0]}, + Conv_N_{b_g_n_c_wis_lengths[1]}, + Conv_K_{e_g_k_c_xs_lengths[1]}, + Conv_C_{b_g_n_c_wis_lengths[2]}, + input_spatial_lengths_{}, + filter_spatial_lengths_{}, + output_spatial_lengths_{}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + c_space_size_bytes = + ck::accumulate_n( + e_g_k_c_xs_lengths.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(WeiDataType); + + constexpr index_t spatial_offset = 3; + 
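            // The convolution length arrays are laid out as {G, N (or K), C, spatial...},
            // so copying from index 3 (spatial_offset) onward extracts just the spatial
            // extents; e.g. for NDimSpatial == 2 an input lengths array {G, N, C, Hi, Wi}
            // yields input_spatial_lengths_ = {Hi, Wi}.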
std::copy(begin(b_g_n_c_wis_lengths) + spatial_offset, + end(b_g_n_c_wis_lengths), + begin(input_spatial_lengths_)); + std::copy(begin(e_g_k_c_xs_lengths) + spatial_offset, + end(e_g_k_c_xs_lengths), + begin(filter_spatial_lengths_)); + std::copy(begin(a_g_n_k_wos_lengths) + spatial_offset, + end(a_g_n_k_wos_lengths), + begin(output_spatial_lengths_)); + + const auto descs = + conv_to_gemm_transformer + .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + Conv_N_, + Conv_K_, + Conv_C_, + input_spatial_lengths_, + filter_spatial_lengths_, + output_spatial_lengths_, + b_g_n_c_wis_strides, + e_g_k_c_xs_strides, + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + k_batch_); + + a_grid_desc_kbatch_k0_m_k1_ = descs[I0]; + b_grid_desc_kbatch_k0_n_k1_ = descs[I1]; + c_grid_desc_m_n_ = descs[I2]; + + // A/B/C Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = + Conv_K_ * Conv_C_ * + std::accumulate(begin(filter_spatial_lengths_), + end(filter_spatial_lengths_), + index_t{1}, + std::multiplies<>{}); + const index_t GemmM = a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_, + GridwiseGemm::CalculateMBlock(GemmM), + GridwiseGemm::CalculateNBlock(GemmN)); + } + + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_; + CGridDesc_M_N c_grid_desc_m_n_; + CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + index_t M01_; + index_t N01_; + + OutElementwiseOperation a_element_op_; + InElementwiseOperation b_element_op_; + WeiElementwiseOperation c_element_op_; + + // for checking IsSupportedArgument() + const index_t Conv_G_; + const index_t Conv_N_; + const index_t Conv_K_; + const index_t Conv_C_; + std::array input_spatial_lengths_; + std::array filter_spatial_lengths_; + std::array output_spatial_lengths_; + const std::array& conv_filter_strides_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + const index_t k_batch_; + long_index_t c_space_size_bytes; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + void ShowInfo(const Argument& arg) + { + std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{" + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " + << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " + << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; + + std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", " + << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; + } + + float Run(const Argument& arg, const StreamConfig& stream_config 
= StreamConfig{}) + { + const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + + const ADataType* p_a_grid = arg.p_a_grid_; + const BDataType* p_b_grid = arg.p_b_grid_; + typename GridwiseGemm::Argument gemm_arg{ + p_a_grid, p_b_grid, arg.p_c_grid_, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( + gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_); + + float ave_time = 0; + + index_t k_grain = gemm_arg.KBatch * K0PerBlock; + index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * K0PerBlock; + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto num_k_per_block = + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch; + + const auto clear_workspace = [&]() { + if(arg.k_batch_ > 1) + { + hip_check_error(hipMemsetAsync( + gemm_arg.p_c_grid, 0, arg.c_space_size_bytes, stream_config.stream_id_)); + } + }; + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; + ck::utility::RotatingMemWrapper rotating_mem( + gemm_arg_, + stream_config.rotating_count, + gemm_arg_.M * gemm_arg_.K * sizeof(ADataType), + gemm_arg_.K * gemm_arg_.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + clear_workspace(); + }; + ave_time += ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg_, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + else + { + ave_time += launch_and_time_kernel_with_preprocess( + stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_kbatch_k0_m_k1_, + arg.b_grid_desc_kbatch_k0_n_k1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + num_k_per_block); + } + }; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 
1 : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + 
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = 
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + } + + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, 
+ minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(gemm_arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(gemm_arg.KBatch > 1) + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3< + GridwiseGemm, + remove_reference_t, + remove_reference_t, + remove_reference_t< + DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + } + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); + const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * + arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2); + + typename GridwiseGemm::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const auto num_k_loop = gemm_arg.AK0 / (K0PerBlock / K1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + if(!ck::is_xdl_supported()) + 
{ + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && + arg.k_batch_ > 1) + { + return false; + } + + if constexpr(NDimSpatial == 1) + { + if constexpr(!is_GNWC_GKXC_GNWK()) + { + return false; + } + } + else if constexpr(NDimSpatial == 2) + { + if constexpr(!(is_NHWGC_GKYXC_NHWGK() || + is_GNHWC_GKYXC_GNHWK())) + { + return false; + } + } + else if constexpr(NDimSpatial == 3) + { + if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK() || + is_GNDHWC_GKZYXC_GNDHWK())) + { + return false; + } + } + else + { + return false; + } + + if constexpr(ConvBackwardWeightSpecialization == + ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { +// workaround: disable when K, C is even +#if CK_WORKAROUND_DISABLE_FILTER1x1STRIDE1PAD0_WHEN_K_C_IS_EVEN + if(arg.Conv_C_ % 2 == 0 || arg.Conv_K_ % 2 == 0) + { + return false; + } +#endif + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + return false; + } + } + } + if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1 && + arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 && + arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + + // vector store C matrix into global memory + if(!(arg.Conv_C_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0)) + { + return false; + } + + // Gridwise GEMM size + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const InDataType* p_in_grid, + WeiDataType* p_wei_grid, + const OutDataType* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) + { + return Argument{p_in_grid, + p_wei_grid, + p_out_grid, + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in_grid, + void* p_wei_grid, + const void* p_out_grid, + const std::array& b_g_n_c_wis_lengths, // input + const std::array& b_g_n_c_wis_strides, + const std::array& e_g_k_c_xs_lengths, // weight + const std::array& e_g_k_c_xs_strides, + const std::array& a_g_n_k_wos_lengths, // output + const std::array& a_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + InElementwiseOperation in_element_op, + WeiElementwiseOperation wei_element_op, + OutElementwiseOperation out_element_op, + const ck::index_t split_k) override + { + return 
std::make_unique(static_cast(p_in_grid), + static_cast(p_wei_grid), + static_cast(p_out_grid), + b_g_n_c_wis_lengths, // input + b_g_n_c_wis_strides, + e_g_k_c_xs_lengths, // weight + e_g_k_c_xs_strides, + a_g_n_k_wos_lengths, // output + a_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + in_element_op, + wei_element_op, + out_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdWeight_Xdl_CShuffleV3" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", " + << K1 << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp new file mode 100755 index 000000000000..3e14f66a09c9 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,1056 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. 
+ * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link + * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of + * pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_dl_multiple_d( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, + const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \ + defined(__gfx12__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t c_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType); + + __shared__ ABDataType p_shared[shared_block_size]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = DsGridDesc_M0_M10_M11_N0_N10_N11::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseGemm::Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + c_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m0_m1_k1, + b_grid_desc_k0_n0_n1_k1, + ds_grid_desc_m0_m10_m11_n0_n10_n11, + e_grid_desc_m0_m10_m11_n0_n10_n11, + 
block_2_ctile_map, + integral_constant{}, + integral_constant{}); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = batch_count; + ignore = a_grid_desc_k0_m0_m1_k1; + ignore = b_grid_desc_k0_n0_n1_k1; + ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = e_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; + + compute_ptr_offset_of_batch.GetAPtrOffset(0); + compute_ptr_offset_of_batch.GetBPtrOffset(0); + compute_ptr_offset_of_batch.GetEPtrOffset(0); +#endif +} +} // namespace + +// +// @brief Device Convolution operation. +// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template +struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK + : public DeviceGroupedConvFwdMultipleABD +{ + using DeviceOp = DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, K0PerBlock}; + + template + static auto + MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + const auto M = in_gemmm_gemmk_desc.GetLength(I0); + const auto K = in_gemmm_gemmk_desc.GetLength(I1); + const auto AK0 = K / K1; + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto + MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + const auto N = wei_gemmn_gemmk_desc.GetLength(I0); + const auto K = wei_gemmn_gemmk_desc.GetLength(I1); + + const auto BK0 = K / K1; + + return transform_tensor_descriptor( + wei_gemmn_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& 
conv_to_gemm_transformer) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer); + }, + Number{}); + } + + // desc for problem definition + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; + using AGridDesc_AK0_M_AK1 = remove_cvref_t( + dummy_conv_to_gemm_transformer))>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t( + dummy_conv_to_gemm_transformer))>; + using DsGridDesc_M_N = + remove_cvref_t; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDlMultipleD_km_kn_mn; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_AK0_M_AK1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_BK0_N_BK1{})); + using DsGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{})); + using DefaultBlock2CTileMap = + decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_c_wis_lengths[0]}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, + a_grid_desc_ak0_m_ak1_{ + DeviceOp::MakeAGridDescriptor_AK0_M_AK1(conv_to_gemm_transformer_)}, + b_grid_desc_bk0_n_bk1_{ + DeviceOp::MakeBGridDescriptor_BK0_N_BK1(conv_to_gemm_transformer_)}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, + a_grid_desc_k0_m0_m1_k1_{}, + b_grid_desc_k0_n0_n1_k1_{}, + ds_grid_desc_m0_m10_m11_n0_n10_n11_{}, + e_grid_desc_m0_m10_m11_n0_n10_n11_{}, + block_2_ctile_map_{}, + compute_ptr_offset_of_batch_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // A/B/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = 
a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths[i], + ds_g_n_k_wos_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); + + // D batch stride + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity( + a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, e_grid_desc_m_n_)) + { + + a_grid_desc_k0_m0_m1_k1_ = + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_ak0_m_ak1_); + b_grid_desc_k0_n0_n1_k1_ = + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_bk0_n_bk1_); + e_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(e_grid_desc_m_n_); + + ds_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_); + } + } + + void Print() const + { + std::cout << "A[K0, M, K1]: " << a_grid_desc_ak0_m_ak1_ << std::endl; + std::cout << "B[K0, N, K1]: " << b_grid_desc_bk0_n_bk1_ << std::endl; + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + std::cout << "num_group: " << num_group_ << std::endl; + + std::cout << "A[k0, m0, m1, k1]: " << a_grid_desc_k0_m0_m1_k1_ << std::endl; + std::cout << "B[k0, n0, n1, k1]: " << b_grid_desc_k0_n0_n1_k1_ << std::endl; + std::cout << "A[m0, m10, m11, n0, n10, n11]: " << e_grid_desc_m0_m10_m11_n0_n10_n11_ + << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + index_t num_group_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_; + + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_; + BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_; + DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_; + CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_; + + // block-to-e-tile map + DefaultBlock2CTileMap block_2_ctile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_k_wos_lengths_; + std::array, NumDTensor> ds_g_n_k_wos_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array 
conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.e_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK has invalid setting"); + } + + const index_t grid_size = + GridwiseGemm::CalculateGridSize(arg.e_grid_desc_m_n_.GetLength(I0), + arg.e_grid_desc_m_n_.GetLength(I1)) * + arg.num_group_; + + auto launch_kernel = [&](auto has_main_k_block_loop, + auto has_double_tail_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr bool has_double_loop = has_double_tail_k_block_loop; + + const auto kernel = kernel_grouped_conv_fwd_dl_multiple_d< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_K0_M0_M1_K1, + DeviceOp::BGridDesc_K0_N0_N1_K1, + DeviceOp::DsGridDesc_M0_M10_M11_N0_N10_N11, + DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11, + DefaultBlock2CTileMap, + ComputePtrOffsetOfStridedBatch, + has_main_loop, + has_double_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_c_wis_lengths_[0], // Group count + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.ds_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.e_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + return 0; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + // check device + if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || + ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())) + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = 
arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + std::cout << "Filter1x1Stride1Pad0 check: XY_index = " << i << " X = " << X + << " ConvStride = " << ConvStride << " LeftPad = " << LeftPad + << " RightPad = " << RightPad << std::endl; + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + std::cout << "Filter1x1Stride1Pad0 check: XY_index = " << i << " X = " << X + << " LeftPad = " << LeftPad << " RightPad = " << RightPad + << std::endl; + return false; + } + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{}; + if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1) + { + return false; + } + if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0) + { + return false; + } + + const index_t C = arg.a_g_n_c_wis_lengths_[2]; + + if(C % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{}; + if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1) + { + return false; + } + if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0) + { + return false; + } + + const index_t C = arg.b_g_k_c_xs_lengths_[2]; + + if(C % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.e_g_n_k_wos_lengths_[2]; + + if(!(K % CThreadTransferDstScalarPerVector == 0 && CThreadTransferSrcDstVectorDim == 5)) + { + return false; + } + } + else + { + return false; + } + + // check Gridwise GEMM + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.e_grid_desc_m_n_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const 
AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto + MakeArgument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const 
std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << 
getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << K1 + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp new file mode 100755 index 000000000000..50e171e503cd --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,833 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +struct ComputePtrOffsetOfStridedBatch +{ + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC) + : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideC_); + } + + index_t BatchStrideA_; + index_t BatchStrideB_; + index_t BatchStrideC_; +}; + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). 
\link + * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link + * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of + * pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_dl( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + const index_t batch_count, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \ + defined(__gfx11__) || defined(__gfx12__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType); + + __shared__ ABDataType p_shared[shared_block_size]; + + GridwiseGemm::Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc_k0_m0_m1_k1, + b_grid_desc_k0_n0_n1_k1, + c_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_ctile_map, + integral_constant{}, + integral_constant{}); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m0_m1_k1; + ignore = b_grid_desc_k0_n0_n1_k1; + ignore = c_grid_desc_m0_m10_m11_n0_n10_n11; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; + + compute_ptr_offset_of_batch.GetAPtrOffset(0); + compute_ptr_offset_of_batch.GetBPtrOffset(0); + compute_ptr_offset_of_batch.GetCPtrOffset(0); +#endif +} + +} // namespace + +// +// @brief Device Convolution operation. 
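+//
+// Note: this variant is built on the DL (non-XDL) gridwise GEMM pipeline
+// (GridwiseGemmDl_km_kn_mn_v1r3); the devices it accepts are listed in
+// IsSupportedArgument below (gfx906, gfx103x, gfx11, gfx12).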
+// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template < + index_t NDimSpatial, + typename ADataType, + typename BDataType, + typename CDataType, + typename AccDataType, + typename ALayout, + typename BLayout, + typename CLayout, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + ConvolutionForwardSpecialization ConvForwardSpecialization, + GemmSpecialization GemmSpec, + index_t BlockSize, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t K1, + index_t M1PerThread, + index_t N1PerThread, + index_t KPerThread, + typename M1N1ThreadClusterM1Xs, + typename M1N1ThreadClusterN1Xs, + typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, + typename ABlockTransferSrcVectorTensorContiguousDimOrder, + typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, + typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, + typename BBlockTransferSrcVectorTensorContiguousDimOrder, + typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, + typename CThreadTransferSrcDstAccessOrder, + index_t CThreadTransferSrcDstVectorDim, + index_t CThreadTransferDstScalarPerVector, + enable_if_t< + is_same_v && + is_same_v && + is_same_v, + bool> = false> +struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd +{ + using DeviceOp = DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, K0PerBlock}; + + template + static auto + MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + const auto M = in_gemmm_gemmk_desc.GetLength(I0); + const auto K = in_gemmm_gemmk_desc.GetLength(I1); + + const auto AK0 = K / K1; + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0, K1)), make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto + MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + const auto N = 
wei_gemmn_gemmk_desc.GetLength(I0); + const auto K = wei_gemmn_gemmk_desc.GetLength(I1); + + const auto BK0 = K / K1; + + return transform_tensor_descriptor( + wei_gemmn_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0, K1)), make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto MakeCGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + // desc for problem definition + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; + using AGridDesc_AK0_M_AK1 = remove_cvref_t( + dummy_conv_to_gemm_transformer))>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t( + dummy_conv_to_gemm_transformer))>; + using CGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDl_km_kn_mn_v1r3; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_AK0_M_AK1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_BK0_N_BK1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using DefaultBlock2CTileMap = + decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, + const void* p_b, + void* p_c, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array& c_g_n_k_wos_lengths, + const std::array& c_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_c_grid_{static_cast(p_c)}, + num_group_{a_g_n_c_wis_lengths[0]}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, + a_grid_desc_ak0_m_ak1_{ + DeviceOp::MakeAGridDescriptor_AK0_M_AK1(conv_to_gemm_transformer_)}, + b_grid_desc_bk0_n_bk1_{ + DeviceOp::MakeBGridDescriptor_BK0_N_BK1(conv_to_gemm_transformer_)}, + c_grid_desc_m_n_{ + DeviceOp::MakeCGridDescriptor_M_N(conv_to_gemm_transformer_)}, + a_grid_desc_k0_m0_m1_k1_{}, + b_grid_desc_k0_n0_n1_k1_{}, + c_grid_desc_m0_m10_m11_n0_n10_n11_{}, + block_2_ctile_map_{}, + compute_ptr_offset_of_batch_{ + a_g_n_c_wis_strides[0], b_g_k_c_xs_strides[0], c_g_n_k_wos_strides[0]}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + c_g_n_k_wos_lengths_{c_g_n_k_wos_lengths}, + c_g_n_k_wos_strides_{c_g_n_k_wos_strides}, + 
conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // A/B/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = c_g_n_k_wos_strides[0]; + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity( + a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_)) + { + + a_grid_desc_k0_m0_m1_k1_ = + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_ak0_m_ak1_); + b_grid_desc_k0_n0_n1_k1_ = + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_bk0_n_bk1_); + c_grid_desc_m0_m10_m11_n0_n10_n11_ = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(c_grid_desc_m_n_); + + block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_); + } + } + + void Print() const + { + std::cout << "A[K0, M, K1]: " << a_grid_desc_ak0_m_ak1_ << std::endl; + std::cout << "B[K0, N, K1]: " << b_grid_desc_bk0_n_bk1_ << std::endl; + std::cout << "C[M, N]: " << c_grid_desc_m_n_ << std::endl; + std::cout << "num_group: " << num_group_ << std::endl; + + std::cout << "A[k0, m0, m1, k1]: " << a_grid_desc_k0_m0_m1_k1_ << std::endl; + std::cout << "B[k0, n0, n1, k1]: " << b_grid_desc_k0_n0_n1_k1_ << std::endl; + std::cout << "A[m0, m10, m11, n0, n10, n11]: " << c_grid_desc_m0_m10_m11_n0_n10_n11_ + << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + CDataType* p_c_grid_; + + // tensor descriptors for problem definiton + index_t num_group_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_; + + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + CGridDesc_M_N c_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_; + BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_; + CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_; + + // block-to-e-tile map + DefaultBlock2CTileMap block_2_ctile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array c_g_n_k_wos_lengths_; + std::array c_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + // if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity( + arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! 
DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK has invalid setting"); + } + + const index_t grid_size = + GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_.GetLength(I0), + arg.c_grid_desc_m_n_.GetLength(I1)) * + arg.num_group_; + + auto launch_kernel = [&](auto has_main_k_block_loop, + auto has_double_tail_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr bool has_double_loop = has_double_tail_k_block_loop; + + const auto kernel = + kernel_grouped_conv_fwd_dl; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_g_n_c_wis_lengths_[0], // Group count + arg.a_grid_desc_k0_m0_m1_k1_, + arg.b_grid_desc_k0_n0_n1_k1_, + arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, + arg.block_2_ctile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0); + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + const bool has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + if(has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(has_main_k_block_loop && !has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(!has_main_k_block_loop && has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + // check device + if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() || + ck::is_gfx11_supported() || ck::is_gfx12_supported())) + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + std::cout << "Filter1x1Stride1Pad0 check: i = " << i << " X = " << X + << " ConvStride = " << ConvStride << " LeftPad = " << LeftPad + << " RightPad = " << RightPad << std::endl; + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + std::cout << "Filter1x1Stride1Pad0 check: i = " << i << " X = " << X + << " LeftPad = " << LeftPad << " RightPad = " << RightPad + << std::endl; + return false; + } + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + 
is_same_v) + { + auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{}; + if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1) + { + return false; + } + if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0) + { + return false; + } + + const index_t C = arg.a_g_n_c_wis_lengths_[2]; + + if(C % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{}; + if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1) + { + return false; + } + if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0) + { + return false; + } + + const index_t C = arg.b_g_k_c_xs_lengths_[2]; + + if(C % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0) + { + return false; + } + } + else + { + return false; + } + + // check vector access of C + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.c_g_n_k_wos_lengths_[2]; + + if(!(K % CThreadTransferDstScalarPerVector == 0 && CThreadTransferSrcDstVectorDim == 5)) + { + return false; + } + } + else + { + return false; + } + // check Gridwise GEMM + return GridwiseGemm::CheckValidity( + arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + void* p_c, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array& c_g_n_k_wos_lengths, + const std::array& c_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op) + { + return Argument{p_a, + p_b, + p_c, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + c_g_n_k_wos_lengths, + c_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array& c_g_n_k_wos_lengths, + const std::array& c_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op) override + { + return std::make_unique(p_a, + p_b, + p_c, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + c_g_n_k_wos_lengths, + c_g_n_k_wos_strides, + 
conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + c_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << K1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << ABlockTransferDstScalarPerVector_K1 << ", " + << BBlockTransferSrcScalarPerVector << ", " + << BBlockTransferDstScalarPerVector_K1 << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << CBlockTransferScalarPerVector_NWaveNPerXdl + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp new file mode 100755 index 000000000000..6d2988ba242a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,1953 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/library/utility/numeric.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. 
If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle( + AsPointer p_as_grid, + BsPointer p_bs_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfG compute_ptr_offset_of_groups, + const ComputePtrOffsetOfN compute_ptr_offset_of_n) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + const long_index_t e_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx); + + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; }); + + if constexpr(isMultiA || isMultiB) + { + AsPointer p_as_grid_grp; + BsPointer p_bs_grid_grp; + + const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx); + + // compute_ptr_offset_of_n_ not need BatchStrideB so + // in case of MultiA is false but isMultiB is true + // BatchStrideA_ is not tuple. 
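+        // (i.e. when only B is a tuple, compute_ptr_offset_of_n carries a single scalar A offset,
+        // which is added to the lone A pointer in the else branch below)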
+ if constexpr(isMultiA) + { + const auto& as_n_offset = compute_ptr_offset_of_n.GetAsPtrOffset(n_idx); + + static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size(); + static_for<0, NumATensor, 1>{}([&](auto i) { + p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + as_n_offset[i]; + }); + } + else + { + const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx); + static_for<0, 1, 1>{}( + [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + a_n_offset; }); + } + + const auto& bs_group_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx); + + static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size(); + static_for<0, NumBTensor, 1>{}( + [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_group_offset[i]; }); + + GridwiseGemm::template Run( + p_as_grid_grp, + p_bs_grid_grp, + p_ds_grid_grp, + p_e_grid + e_group_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } + else + { + const long_index_t b_group_offset = + CTranspose + ? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t a_group_offset = + CTranspose + ? amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)) + : amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_n_offset = + CTranspose ? amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)) : 0; + const long_index_t a_n_offset = + CTranspose ? 0 : amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + + GridwiseGemm::template Run( + p_as_grid + a_group_offset + a_n_offset, + p_bs_grid + b_group_offset + b_n_offset, + p_ds_grid_grp, + p_e_grid + e_group_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map); + } +#else + ignore = p_as_grid; + ignore = p_bs_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_groups; + ignore = compute_ptr_offset_of_n; + ignore = block_2_ctile_map; +#endif +} + +} // namespace +#ifdef CK_CODE_GEN_RTC +template +using is_tuple = decltype(ck::declval().IsTuple()); +#else +template +using is_tuple = decltype(std::declval().IsTuple()); +#endif + +// +// @brief Device Convolution operation. 
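The is_tuple alias defined just above is a detection-idiom probe: paired with is_detected (see isMultiA/isMultiB below), it is well-formed only for types exposing an IsTuple() member, which is how multi-A/multi-B operands are recognized. A standalone sketch of the same idiom using std::void_t, with illustrative type names:

#include <type_traits>
#include <utility>

template <typename, typename = void>
struct has_is_tuple : std::false_type {};

template <typename T>
struct has_is_tuple<T, std::void_t<decltype(std::declval<T>().IsTuple())>> : std::true_type {};

struct PlainTensorType {};
struct TupleLikeType { void IsTuple(); };

static_assert(!has_is_tuple<PlainTensorType>::value, "no IsTuple() -> single tensor");
static_assert(has_is_tuple<TupleLikeType>::value, "IsTuple() present -> tuple of tensors");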
+// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template ::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + typename BComputeDataType = AComputeDataType, + LoopScheduler LoopSched = make_default_loop_scheduler(), + index_t NumGroupsToMerge = 1> +struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle + : public DeviceGroupedConvFwdMultipleABD +{ + using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle; + + static_assert(NumGroupsToMerge >= 1); + + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; + static constexpr bool isMultiAB = isMultiA || isMultiB; + + // NGCHW is not supported for multiAB + static_assert(!(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) || + !(isMultiA || isMultiB)); + + static constexpr index_t NumATensor = GetNumABTensors(); + static constexpr index_t NumBTensor = GetNumABTensors(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr bool DoElementwiseBeforeCShuffle = + NumDTensor == 0 && !isMultiAB && is_same_v && + !is_same_v; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr bool isATensorColMajor = + (ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) && + (ABlockTransferSrcVectorDim == 1) && (NumGroupsToMerge == 1) && + (is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()); + + static constexpr bool NeedTransposeKernel = + (isATensorColMajor == false) && (is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()); + + static constexpr bool CTranspose = (NeedTransposeKernel == false) && (isMultiAB == false) && + (is_same_v || + is_same_v); + + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + + static constexpr index_t ClusterLengthNPerBlock = + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + static constexpr auto conv_ngchw_to_nhwgc_transformer = + TransformConvNGCHWToNHWGC{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + namespace ctc = tensor_layout::convolution; + using Layout = std::conditional_t< + is_NGCHW_NGKHW() && NeedTransposeKernel, + ctc::NHWGC, + std::conditional_t() && NeedTransposeKernel, + ctc::NDHWGC, + ALay>>; + + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + return in_gemmm_gemmk_desc; + } + + template + static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + namespace ctc = tensor_layout::convolution; + using Layout = std::conditional_t< + is_NGCHW_NGKHW() && NeedTransposeKernel, + ctc::GKYXC, + std::conditional_t() && NeedTransposeKernel, + ctc::GKZYXC, + BLay>>; + + const auto 
wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + return wei_gemmn_gemmk_desc; + } + + template + static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + namespace ctc = tensor_layout::convolution; + using Layout = std::conditional_t< + is_NGCHW_NGKHW() && NeedTransposeKernel, + ctc::NHWGK, + std::conditional_t() && NeedTransposeKernel, + ctc::NDHWGK, + ELay>>; + + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); + if constexpr(CTranspose) + { + constexpr auto matrix_padder_trans = + MatrixPadder{NPerBlock, MPerBlock, KPerBlock}; + return matrix_padder_trans.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + } + else + { + return matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + } + } + + // Shape of Ds and E must be aligned. Strides can be different. + // Pass e_g_n_k_wos_lengths for logical broadcast. + static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer); + }, + Number{}); + } + + // desc for problem definition + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; + using AGridDesc_M_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using BGridDesc_N_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using DsGridDesc_M_N = + remove_cvref_t; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + + // If we are using multiAB and one of the template datatype parameters is not a tuple, convert + // it to it + using GemmADataType = std::conditional_t, ADataType>; + using GemmBDataType = std::conditional_t, BDataType>; + +#define GridwiseGemmMultiABDTemplateParameters \ + GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ + KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ + BComputeDataType + +#define GridwiseGemmTemplateParameters \ + GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ + NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, 
\ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ + BComputeDataType, DoElementwiseBeforeCShuffle + +#define GridwiseGemmCTransposeTemplateParameters \ + GemmBDataType, GemmADataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, BElementwiseOperation, AElementwiseOperation, CDEElementwiseOperation, \ + NumGemmKPrefetchStage, BlockSize, NPerBlock, MPerBlock, KPerBlock, BK1, AK1, NPerXDL, \ + MPerXDL, NXdlPerWave, MXdlPerWave, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ + BComputeDataType, DoElementwiseBeforeCShuffle + + // Use appropriate gridwise gemm + using GridwiseGemm = std::conditional_t< + isMultiA || isMultiB, + GridwiseGemmMultipleABD_xdl_cshuffle, + GridwiseGemmMultipleD_xdl_cshuffle>; + using GridwiseGemmCTranspose = std::conditional_t< + CTranspose, + GridwiseGemmMultipleD_xdl_cshuffle, + GridwiseGemm>; + + // If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers. + using APointers = + std::conditional_t&, const void*>; + using BPointers = + std::conditional_t&, const void*>; + // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not + // in initializer list what is required for single const pointer). 
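The APointers/BPointers aliases above switch the public calling convention at compile time: when the A (or B) data type is a tuple, the caller passes a std::array of raw device pointers; otherwise a single const void*. A compile-time sketch of that selection, where the bool and count parameters stand in for isMultiA and NumATensor:

#include <array>
#include <type_traits>

template <bool IsMulti, int NumTensor>
using PointersSketch = std::conditional_t<IsMulti,
                                          const std::array<const void*, NumTensor>&,
                                          const void*>;

static_assert(std::is_same_v<PointersSketch<true, 2>, const std::array<const void*, 2>&>);
static_assert(std::is_same_v<PointersSketch<false, 1>, const void*>);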
+ using AGridPointer = remove_cvref_t< + decltype(GetAGridPointer < isMultiA || isMultiB, GridwiseGemm, ADataType > ())>; + using BGridPointer = remove_cvref_t< + decltype(GetBGridPointer < isMultiA || isMultiB, GridwiseGemm, BDataType > ())>; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemmCTranspose::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemmCTranspose::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + EGridDesc_M_N{}))>; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; + + using NGCHWTransposeDescType = + remove_cvref_t({}, {}))>; + using NHWGCTransposeDescType = + remove_cvref_t({}, {}))>; + + using GKCYXTransposeDescType = + remove_cvref_t({}, {}))>; + using GKYXCTransposeDescType = + remove_cvref_t({}, {}))>; + + static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock; + + // NPerBlock is used for the first and second dim which to use + // CDEBlockTransferScalarPerVector_NPerBlock for load and store during + // transposition. CBlockTransferScalarPerVector_NWaveNPerXdl is aligned to + // NPerBlock so it is more flexible to use this dim for load store dimension + // with such scalar per vector. + using GridwiseElementwiseInputTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + NPerBlock, + NPerBlock / ClusterLengthNPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I1, + I0>; + + using GridwiseElementwiseWeightTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + NPerBlock, + NPerBlock / ClusterLengthNPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence<1>, + Sequence, + I0, + I1>; + + using GridwiseElementwiseOutputTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + NPerBlock, + NPerBlock / ClusterLengthNPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I0, + I1>; + + // Argument + struct Argument : public BaseArgument + { + Argument(APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_as_grid_{}, + p_bs_grid_{}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{NeedTransposeKernel + ? 
conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides) + : a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{NeedTransposeKernel + ? conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides) + : b_g_k_c_xs_strides}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{NeedTransposeKernel + ? conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides) + : e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + num_group_{a_g_n_c_wis_lengths_[0]}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths_, + a_g_n_c_wis_strides_, + b_g_k_c_xs_lengths_, + b_g_k_c_xs_strides_, + e_g_n_k_wos_lengths_, + e_g_n_k_wos_strides_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_}, + conv_N_per_block_{conv_to_gemm_transformer_.N_}, + a_grid_desc_m_k_{ + DeviceOp::MakeAGridDescriptor_M_K(conv_to_gemm_transformer_)}, + b_grid_desc_n_k_{ + DeviceOp::MakeBGridDescriptor_N_K(conv_to_gemm_transformer_)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{ + GridwiseGemmCTranspose::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + compute_ptr_offset_of_groups_{}, + compute_ptr_offset_of_n_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + // A/B/E Batch Stride + if constexpr(isMultiA || isMultiB) + { + static_for<0, NumATensor, 1>{}([&](auto i) { + // Init compute_ptr_offset_of_groups_ for multiple AB + compute_ptr_offset_of_groups_.BatchStrideA_(i) = + a_g_n_c_wis_strides_[0] * NumGroupsToMerge; + + // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data + // type is not tuple) + using DataType = remove_cvref_t>; + // It is possible that one of the AB is a pointer and one is a tuple. + // Then also use multiAB but we have to cast single pointer instead of tuple of + // pointer. + if constexpr(isMultiA) + { + // p_as is tuple + p_as_grid_(i) = static_cast(p_as[i.value]); + // compute_ptr_offset_of_n_ not need BatchStrideB so + // in case of MultiA is false but isMultiB is true + // BatchStrideA_ is not tuple. + compute_ptr_offset_of_n_.BatchStrideA_(i) = + a_g_n_c_wis_strides_[1] * conv_N_per_block_; + } + else + { + // if MultiB and not MultiA then p_as is single pointer + p_as_grid_(i) = static_cast(p_as); + compute_ptr_offset_of_n_.BatchStrideA_ = + a_g_n_c_wis_strides_[1] * conv_N_per_block_; + } + }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + // Init compute_ptr_offset_of_groups_ for multiple AB + compute_ptr_offset_of_groups_.BatchStrideB_(i) = + b_g_k_c_xs_strides_[0] * NumGroupsToMerge; + + using DataType = remove_cvref_t>; + // It is possible that one of the AB is a pointer and one is a tuple. + // Then also use multiAB but we have to cast single pointer instead of tuple of + // pointer. 
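The strides assigned above come straight from the G- and N-strides of the [G, N, C, spatial...] descriptors: group g of A starts g * strides[0] elements into the tensor, and each blockIdx.z slice advances by strides[1] * conv_N_per_block_ elements. A small worked example with made-up sizes (packed NHWGC memory, strides in elements):

#include <cstdint>

int main()
{
    // Hypothetical 2D input with G = 8, N = 4, C = 64, Hi = Wi = 32, stored as
    // packed NHWGC; strides are ordered [G, N, C, Hi, Wi] to match the arguments.
    const int64_t strides[5] = {64, 8 * 64 * 32 * 32, 1, 8 * 64 * 32, 8 * 64};

    const int64_t num_groups_to_merge = 1; // NumGroupsToMerge
    const int64_t conv_n_per_block    = 2; // images covered by one blockIdx.z slice

    const int64_t batch_stride_a_groups = strides[0] * num_groups_to_merge; // 64
    const int64_t batch_stride_a_n      = strides[1] * conv_n_per_block;    // 1048576

    return (batch_stride_a_groups == 64 && batch_stride_a_n == 1048576) ? 0 : 1;
}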
+ if constexpr(isMultiB) + { + // p_bs is tuple + p_bs_grid_(i) = static_cast(p_bs[i.value]); + } + else + { + // if MultiA and not MultiB then p_bs is single pointer + p_bs_grid_(i) = static_cast(p_bs); + } + }); + } + else + { + compute_ptr_offset_of_groups_.BatchStrideA_ = + a_g_n_c_wis_strides_[0] * NumGroupsToMerge; + compute_ptr_offset_of_groups_.BatchStrideB_ = + b_g_k_c_xs_strides_[0] * NumGroupsToMerge; + compute_ptr_offset_of_n_.BatchStrideA_ = + a_g_n_c_wis_strides_[1] * conv_N_per_block_; + + // p_as and p_bs are pointers + p_as_grid_(I0) = static_cast(p_as); + p_bs_grid_(I0) = static_cast(p_bs); + } + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); + + // D batch stride + compute_ptr_offset_of_groups_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides_[i][0] * NumGroupsToMerge; + compute_ptr_offset_of_n_.BatchStrideDs_(i) = + ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths_, + a_g_n_c_wis_strides_, + b_g_k_c_xs_lengths_, + b_g_k_c_xs_strides_, + e_g_n_k_wos_lengths_, + ds_g_n_k_wos_strides_[i], + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_}; + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); + }); + compute_ptr_offset_of_groups_.BatchStrideE_ = + e_g_n_k_wos_strides_[0] * NumGroupsToMerge; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_; + + // populate desc for Ds/E + if constexpr(isMultiA || isMultiB) + { + const auto as_grid_desc_ak0_m_ak1 = + generate_tuple([&](auto) { return a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = + generate_tuple([&](auto) { return b_grid_desc_n_k_; }, Number{}); + + if(GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + } + else + { + bool valid = false; + if constexpr(CTranspose) + { + valid = GridwiseGemmCTranspose::CheckValidity(b_grid_desc_n_k_, + a_grid_desc_m_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_); + } + else + { + valid = GridwiseGemmCTranspose::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_); + } + if(valid) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose:: + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = GridwiseGemmCTranspose:: + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_); + } + } + + if constexpr(NeedTransposeKernel) + { + // Use not modified base strides + a_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides); + a_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides); + + b_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc( + 
b_g_k_c_xs_lengths, b_g_k_c_xs_strides); + b_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides); + + e_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides); + e_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides); + + elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{ + a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{ + b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; + elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{ + e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)}; + } + } + + std::size_t GetWorkspaceATensorSizeBytes() const + { + if constexpr(NeedTransposeKernel) + { + const long_index_t a_acum = ck::accumulate_n( + a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + // Align to 128B + return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceBTensorSizeBytes() const + { + if constexpr(NeedTransposeKernel) + { + const long_index_t b_acum = ck::accumulate_n( + b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + // Align to 128B + return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + if constexpr(NeedTransposeKernel) + { + const long_index_t e_accum = ck::accumulate_n( + e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + return sizeof(EDataType) * e_accum; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() + + GetWorkspaceETensorSizeBytes(); + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers (tuple if multi AB, pointer if no) + AGridPointer p_as_grid_; + BGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_k_wos_lengths_; + std::array, NumDTensor> ds_g_n_k_wos_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + + // tensor descriptors for problem definiton + index_t num_group_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_; + + index_t conv_N_per_block_; + + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + 
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_, + elementwise_block_2_ctile_map_transpose_b_, elementwise_block_2_ctile_map_transpose_e_; + + NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_; + NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_; + GKCYXTransposeDescType b_in_transpose_desc_; + GKYXCTransposeDescType b_out_transpose_desc_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch + compute_ptr_offset_of_groups_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + + const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); + const index_t gdy = arg.num_group_ / NumGroupsToMerge; + const index_t gdz = num_workgroups_per_Conv_N; + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + if constexpr(isMultiA || isMultiB) + { + // Generate tuples with grid descriptors for each A and B + const auto as_grid_desc_ak0_m_ak1 = generate_tuple( + [&](auto) { return arg.a_grid_desc_ak0_m_ak1_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = generate_tuple( + [&](auto) { return arg.b_grid_desc_bk0_n_bk1_; }, Number{}); + + const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle< + GridwiseGemm, + AGridPointer, + BGridPointer, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + decltype(as_grid_desc_ak0_m_ak1), + decltype(bs_grid_desc_bk0_n_bk1), + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + Block2ETileMap, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + has_main_loop, + isMultiA, + isMultiB, + CTranspose>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.p_as_grid_, + arg.p_bs_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); + } + else + { + const ADataType* p_a_grid = arg.p_as_grid_.At(I0); + const BDataType* p_b_grid = arg.p_bs_grid_.At(I0); + EDataType* p_e_grid = arg.p_e_grid_; + if constexpr(NeedTransposeKernel) + { + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_); + p_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() 
/ sizeof(BDataType); + p_e_grid = type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + } + else if constexpr(is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_); + p_e_grid = type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + } + } + + if constexpr(CTranspose) + { + const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle< + GridwiseGemmCTranspose, + const BDataType*, + const ADataType*, + typename GridwiseGemm::DsGridPointer, + EDataType, + BElementwiseOperation, + AElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + Block2ETileMap, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + has_main_loop, + isMultiA, + isMultiB, + CTranspose>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + p_b_grid, + p_a_grid, + arg.p_ds_grid_, + p_e_grid, + arg.b_element_op_, + arg.a_element_op_, + arg.cde_element_op_, + arg.b_grid_desc_bk0_n_bk1_, + arg.a_grid_desc_ak0_m_ak1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); + } + else + { + const auto kernel = kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle< + GridwiseGemm, + const ADataType*, + const BDataType*, + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + Block2ETileMap, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + has_main_loop, + isMultiA, + isMultiB, + CTranspose>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + p_a_grid, + p_b_grid, + arg.p_ds_grid_, + p_e_grid, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); + } + } + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0.f; + + if constexpr(NeedTransposeKernel) + { + const index_t a_grid_size = + arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( + arg.a_in_transpose_desc_); + const index_t b_grid_size = + (is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + ? 
arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( + arg.b_in_transpose_desc_) + : 0; // Dont run transpose B if not needed + + ADataType* p_a_out_grid = type_convert(arg.p_workspace_); + BDataType* p_b_out_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + + auto kernel_transpose = kernel_elementwise_dual, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + Block2TileMapElementwise, + element_wise::PassThrough>; + + avg_time += launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(a_grid_size + b_grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.b_in_transpose_desc_), + make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.b_out_transpose_desc_), + make_tuple(arg.p_as_grid_.At(I0)), + make_tuple(arg.p_bs_grid_.At(I0)), + make_tuple(p_a_out_grid), + make_tuple(p_b_out_grid), + arg.elementwise_block_2_ctile_map_transpose_a_, + arg.elementwise_block_2_ctile_map_transpose_b_, + element_wise::PassThrough{}, + a_grid_size); + } + + avg_time += RunGemm(arg, stream_config); + + if constexpr(NeedTransposeKernel) + { + const index_t grid_size = + arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( + arg.e_in_transpose_desc_); + + const EDataType* p_e_in_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + + EDataType* p_e_out_grid = arg.p_e_grid_; + + auto kernel_transpose = kernel_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + element_wise::PassThrough>; + + avg_time += launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.e_in_transpose_desc_), + make_tuple(arg.e_out_transpose_desc_), + make_tuple(p_e_in_grid), + make_tuple(p_e_out_grid), + arg.elementwise_block_2_ctile_map_transpose_e_, + element_wise::PassThrough{}); + } + + return avg_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + const index_t G = arg.b_g_k_c_xs_lengths_[I0]; + const index_t K = arg.b_g_k_c_xs_lengths_[I1]; + const index_t C = arg.b_g_k_c_xs_lengths_[I2]; + const index_t input_spatial_acum = ck::accumulate_n( + arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + + // check device + if(get_device_name() == "gfx908") + { + // FIXME: re-enable fp64 when SWDEV-335738 is fixed + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + if(!ck::is_xdl_supported()) + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { 
+ // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3) + { + if(C != 1) + { + return false; + } + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3]; + + if(filter_spatial_dim != I3) + { + return false; + } + } + } + + if constexpr(NumGroupsToMerge > 1) + { + if(!(C == 1)) + { + return false; + } + if(G % NumGroupsToMerge != 0) + { + return false; + } + if constexpr(!(is_NSpatialGC_GKSpatial_NSpatialGK() || + is_NGCSpatial_GKSpatial_NGKSpatial() || + is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW())) + { + return false; + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + NeedTransposeKernel) + { + // Check access per C + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + // If not possible, check access per G + if(!(ABlockTransferSrcVectorDim == 1 && (C == 1 || NumGroupsToMerge == 1) && + (is_NSpatialGC_GKSpatial_NSpatialGK() || + is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) && + G % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + } + else if constexpr(is_same_v || is_same_v) + { + static_assert(NeedTransposeKernel == false); + static_assert(NumGroupsToMerge == 1); + + if constexpr(ABlockTransferSrcScalarPerVector != 1) + { + if(ABlockTransferSrcVectorDim != 1) + { + return false; + } + if(input_spatial_acum % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + + { + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + // check vector access of Ds + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + { + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + valid = false; + } + + if constexpr(is_same_v) + { + // G and K must be the same + if(arg.ds_g_n_k_wos_lengths_[i][0] != arg.e_g_n_k_wos_lengths_[0] || + arg.ds_g_n_k_wos_lengths_[i][2] != arg.e_g_n_k_wos_lengths_[2]) + { + valid = false; + } + } + else + { + // E and D must have the same shape + for(index_t d = 0; d < NDimSpatial + 3; d++) + { + if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d]) + { + valid = false; + } + } + } + } + else + { + valid = false; + } + }); + + if constexpr(NeedTransposeKernel) + { + if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + + if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + + const index_t output_spatial_acum = ck::accumulate_n( + arg.e_g_n_k_wos_lengths_.begin() + I3, 
NDimSpatial, 1, std::multiplies<>()); + + if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + + if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + + if(!arg.p_workspace_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Warning: Workspace for " + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument is not " + "allocated, use SetWorkSpacePointer." + << std::endl; + } + return false; + } + + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + arg.e_in_transpose_desc_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + } + + if(!valid) + { + return false; + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + { + if(CTranspose == false) + { + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + const index_t output_spatial_acum = ck::accumulate_n( + arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + + if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + } + else + { + return false; + } + + // check Gridwise GEMM + if constexpr(isMultiA || isMultiB) + { + // Genarate tuples with the same descriptors + const auto as_grid_desc_ak0_m_ak1 = + generate_tuple([&](auto) { return arg.a_grid_desc_m_k_; }, Number{}); + const auto bs_grid_desc_bk0_n_bk1 = + generate_tuple([&](auto) { return arg.b_grid_desc_n_k_; }, Number{}); + return GridwiseGemm::CheckValidity(as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + if constexpr(CTranspose) + { + return GridwiseGemmCTranspose::CheckValidity(arg.b_grid_desc_n_k_, + arg.a_grid_desc_m_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + else + { + return GridwiseGemmCTranspose::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + } + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + 
b_element_op, + cde_element_op}; + } + + static auto + MakeArgument(APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + return std::make_unique(p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + 
e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr + MakeArgumentPointer(APointers p_as, + BPointers p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return std::make_unique(p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CDEBlockTransferScalarPerVector_NPerBlock << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << NumGroupsToMerge + << ">"; + // clang-format on + + return str.str(); + } + + size_t 
GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp new file mode 100755 index 000000000000..48424c16b921 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -0,0 +1,1698 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/library/utility/numeric.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. 
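The GetWorkSpaceSize/SetWorkSpacePointer overrides shown above imply a fixed call order whenever the NGCHW transpose path is taken: query the size, allocate, hand the pointer back, then run. A hedged usage sketch; device_op, p_arg, and the stream config are placeholders for a concrete instance, and only HIP runtime calls plus the member functions visible above are used:

#include <hip/hip_runtime.h>
#include <cstddef>
#include <memory>

// DeviceOpT: a device op such as the one above; p_arg: what MakeArgumentPointer() returned.
template <typename DeviceOpT, typename BaseArgumentT, typename StreamConfigT>
float run_with_workspace(DeviceOpT& device_op,
                         std::unique_ptr<BaseArgumentT>& p_arg,
                         const StreamConfigT& stream_config)
{
    auto invoker = device_op.MakeInvokerPointer();

    void* p_workspace                 = nullptr;
    const std::size_t workspace_bytes = device_op.GetWorkSpaceSize(p_arg.get());
    if(workspace_bytes > 0)
    {
        if(hipMalloc(&p_workspace, workspace_bytes) != hipSuccess)
            return -1.f; // allocation failed
        device_op.SetWorkSpacePointer(p_arg.get(), p_workspace);
    }

    float avg_ms = -1.f;
    if(device_op.IsSupportedArgument(p_arg.get()))
        avg_ms = invoker->Run(p_arg.get(), stream_config);

    if(p_workspace != nullptr)
        (void)hipFree(p_workspace);

    return avg_ms;
}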
Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_fwd_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups, + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset + a_n_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset + e_n_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds( + typename GridwiseGemm::Argument karg, + [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, + [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, + [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups, + [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // offset base pointer for each work-group + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = 
+ amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset + a_n_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset + e_n_offset, + p_shared_0, + p_shared_1, + karg, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +} // namespace + +template +using is_tuple = decltype(std::declval().IsTuple()); + +// +// @brief Device Convolution operation. +// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template ::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + typename BComputeDataType = AComputeDataType> +struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 + : public DeviceGroupedConvFwdMultipleABD +{ + using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3; + + static constexpr bool isMultiA = is_detected::value; + static constexpr bool isMultiB = is_detected::value; + static constexpr bool isMultiD = DsDataType::Size() > 0; + static constexpr bool isMultiABD = isMultiA || isMultiB || isMultiD; + + static constexpr bool DoElementwiseBeforeCShuffle = + !isMultiABD && is_same_v && + !is_same_v; + + static constexpr index_t NumATensor = GetNumABTensors(); + static constexpr index_t NumBTensor = GetNumABTensors(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + static constexpr index_t ClusterLengthNPerBlock = + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + static constexpr auto conv_ngchw_to_nhwgc_transformer = + TransformConvNGCHWToNHWGC{}; + + template + static auto + MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + + { + namespace ctc = tensor_layout::convolution; + using Layout = std::conditional_t< + is_NGCHW_GKCYX_NGKHW(), + ctc::NHWGC, + std::conditional_t(), + ctc::NDHWGC, + ALay>>; + + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto 
in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + const auto M = in_gemmm_gemmk_desc.GetLength(I0); + const auto K = in_gemmm_gemmk_desc.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(in_gemmm_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto + MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + namespace ctc = tensor_layout::convolution; + using Layout = std::conditional_t< + is_NGCHW_GKCYX_NGKHW(), + ctc::GKYXC, + std::conditional_t(), + ctc::GKZYXC, + BLay>>; + + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + const auto N = wei_gemmn_gemmk_desc.GetLength(I0); + const auto K = wei_gemmn_gemmk_desc.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(wei_gemmn_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + + { + namespace ctc = tensor_layout::convolution; + using Layout = std::conditional_t< + is_NGCHW_GKCYX_NGKHW(), + ctc::NHWGK, + std::conditional_t(), + ctc::NDHWGK, + ELay>>; + + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + // desc for problem definition + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + +#define GridwiseGemmV3TemplateParams \ + tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \ + tensor_layout::gemm::RowMajor, ADataType, BDataType, AccDataType, CShuffleDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, \ + MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \ + AComputeDataType, BComputeDataType, false, false, DoElementwiseBeforeCShuffle + + // Use appropriate gridwise gemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3; + + using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; + + using NGCHWTransposeDescType = + remove_cvref_t({}, {}))>; + using NHWGCTransposeDescType = + 
remove_cvref_t({}, {}))>; + + using GKCYXTransposeDescType = + remove_cvref_t({}, {}))>; + using GKYXCTransposeDescType = + remove_cvref_t({}, {}))>; + + static constexpr index_t ElementwiseBlocksize = ClusterLengthNPerBlock * ClusterLengthNPerBlock; + + using GridwiseElementwiseInputTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + NPerBlock, + NPerBlock / ClusterLengthNPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I1, + I0>; + + using GridwiseElementwiseWeightTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + NPerBlock, + NPerBlock / ClusterLengthNPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence<1>, + Sequence, + I0, + I1>; + + using GridwiseElementwiseOutputTranspose = + GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMapElementwise, + element_wise::PassThrough, + ElementwiseBlocksize, + NPerBlock, + NPerBlock, + NPerBlock / ClusterLengthNPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<1, 0>, + Sequence, + Sequence, + I0, + I1>; + + static auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const index_t M = e_grid_desc_m_n.GetLength(I0); + const index_t N = e_grid_desc_m_n.GetLength(I1); + return GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n, GridwiseGemm::CalculateMBlock(M), GridwiseGemm::CalculateNBlock(N)); + } + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = remove_cvref_t( + dummy_conv_to_gemm_transformer))>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t( + dummy_conv_to_gemm_transformer))>; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_as, + const void* p_bs, + const std::array&, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>&, + const std::array, NumDTensor>&, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_a_grid_{}, + p_b_grid_{}, + p_e_grid_{static_cast(p_e)}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides)}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides)}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{conv_ngchw_to_nhwgc_transformer.TransposeInOutStrides( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides)}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + num_group_{a_g_n_c_wis_lengths_[0]}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths_, + a_g_n_c_wis_strides_, + b_g_k_c_xs_lengths_, + 
b_g_k_c_xs_strides_, + e_g_n_k_wos_lengths_, + e_g_n_k_wos_strides_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_}, + conv_N_per_block_{conv_to_gemm_transformer_.N_}, + a_grid_desc_ak0_m_ak1_{ + MakeAGridDescriptor_AK0_M_AK1(conv_to_gemm_transformer_)}, + b_grid_desc_bk0_n_bk1_{ + MakeBGridDescriptor_BK0_N_BK1(conv_to_gemm_transformer_)}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + compute_ptr_offset_of_groups_{}, + compute_ptr_offset_of_n_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op} + { + // A/B/E Batch/N Stride + compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides_[0]; + compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0]; + compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_; + + // p_as and p_bs are pointers + p_a_grid_ = static_cast(p_as); + p_b_grid_ = static_cast(p_bs); + + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides_[0]; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_; + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); + + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + // Use not modified base strides + a_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides); + a_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides); + + b_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKCYXTransposeDesc( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides); + b_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeGKYXCTransposeDesc( + b_g_k_c_xs_lengths, b_g_k_c_xs_strides); + + e_in_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNHWGCTransposeDesc( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides); + elementwise_block_2_ctile_map_transpose_b_ = Block2TileMapElementwise{ + b_in_transpose_desc_.GetLength(I0), b_in_transpose_desc_.GetLength(I1)}; + e_out_transpose_desc_ = + conv_ngchw_to_nhwgc_transformer.template MakeNGCHWTransposeDesc( + e_g_n_k_wos_lengths, e_g_n_k_wos_strides); + + elementwise_block_2_ctile_map_transpose_a_ = Block2TileMapElementwise{ + a_in_transpose_desc_.GetLength(I0), a_in_transpose_desc_.GetLength(I1)}; + elementwise_block_2_ctile_map_transpose_e_ = Block2TileMapElementwise{ + e_in_transpose_desc_.GetLength(I0), e_in_transpose_desc_.GetLength(I1)}; + } + } + + std::size_t GetWorkspaceATensorSizeBytes() const + { + if constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + const long_index_t a_acum = ck::accumulate_n( + a_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + // Align to 128B + return math::integer_divide_ceil(sizeof(ADataType) * a_acum, 128) * 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceBTensorSizeBytes() const + { + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + const long_index_t b_acum = ck::accumulate_n( + b_g_k_c_xs_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + // Align to 128B + return math::integer_divide_ceil(sizeof(BDataType) * b_acum, 128) * 128; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceETensorSizeBytes() const + { + if 
constexpr(is_NGCHW_NGKHW() || + is_NGCDHW_NGKDHW()) + { + const long_index_t e_accum = ck::accumulate_n( + e_g_n_k_wos_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()); + return sizeof(EDataType) * e_accum; + } + else + { + return 0; + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + return GetWorkspaceATensorSizeBytes() + GetWorkspaceBTensorSizeBytes() + + GetWorkspaceETensorSizeBytes(); + } + + void Print() const + { + std::cout << "A[AK0, M, AK1]: " << a_grid_desc_ak0_m_ak1_ << std::endl; + std::cout << "B[BK0, N, BK1]: " << b_grid_desc_bk0_n_bk1_ << std::endl; + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers (tuple if multi AB, pointer if no) + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + + // tensor descriptors for problem definiton + index_t num_group_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_; + + index_t conv_N_per_block_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + EGridDesc_M_N e_grid_desc_m_n_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_groups_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // block-to-e-tile map + Block2TileMapElementwise elementwise_block_2_ctile_map_transpose_a_, + elementwise_block_2_ctile_map_transpose_b_, elementwise_block_2_ctile_map_transpose_e_; + + NGCHWTransposeDescType a_in_transpose_desc_, e_out_transpose_desc_; + NHWGCTransposeDescType a_out_transpose_desc_, e_in_transpose_desc_; + GKCYXTransposeDescType b_in_transpose_desc_; + GKYXCTransposeDescType b_out_transpose_desc_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float RunGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + float ave_time = 0; + + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 
1 : 2; + + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = + GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/); + + gdy = arg.num_group_; + gdz = num_workgroups_per_Conv_N; + + index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock; + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const ADataType* p_a_grid = arg.p_a_grid_; + const BDataType* p_b_grid = arg.p_b_grid_; + EDataType* p_e_grid = arg.p_e_grid_; + + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + p_a_grid = type_convert(arg.p_workspace_); + p_b_grid = type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + p_e_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + } + + typename GridwiseGemm::Argument gemm_arg{p_a_grid, + p_b_grid, + p_e_grid, + GemmM, + GemmN, + GemmK, + I0, + I0, + I0, + I1, + false, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_}; + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; + ck::utility::RotatingMemWrapper rotating_mem( + gemm_arg_, + stream_config.rotating_count, + gemm_arg_.M * gemm_arg_.K * sizeof(ADataType), + gemm_arg_.K * gemm_arg_.N * sizeof(BDataType)); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + }; + + ave_time += ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); + } + else + { + ave_time += + launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); + } + }; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + 
InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::One>; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Full>; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + // Tail number could be Odd or Even + else if 
constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + const auto kernel = kernel_grouped_conv_fwd_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy>; + Run(kernel); + } + } + + return ave_time; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float avg_time = 0.f; + if constexpr(!isMultiABD) + { + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + const index_t a_grid_size = + arg.elementwise_block_2_ctile_map_transpose_a_.CalculateGridSize( + arg.a_in_transpose_desc_); + const index_t b_grid_size = + arg.elementwise_block_2_ctile_map_transpose_b_.CalculateGridSize( + arg.b_in_transpose_desc_); + + ADataType* p_a_out_grid = type_convert(arg.p_workspace_); + BDataType* p_b_out_grid = + type_convert(arg.p_workspace_) + + arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); + + auto kernel_transpose = + kernel_elementwise_dual, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + Block2TileMapElementwise, + element_wise::PassThrough>; + + avg_time += + launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(a_grid_size + b_grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.a_in_transpose_desc_), + make_tuple(arg.b_in_transpose_desc_), + make_tuple(arg.a_out_transpose_desc_), + make_tuple(arg.b_out_transpose_desc_), + make_tuple(arg.p_a_grid_), + make_tuple(arg.p_b_grid_), + make_tuple(p_a_out_grid), + make_tuple(p_b_out_grid), + arg.elementwise_block_2_ctile_map_transpose_a_, + arg.elementwise_block_2_ctile_map_transpose_b_, + element_wise::PassThrough{}, + a_grid_size); + } + 
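+            // For NGCHW/NGCDHW layouts the A and B tensors were transposed above into the
+            // NHWGC/GKYXC workspace buffers; RunGemm below consumes those transposed copies
+            // and, for these layouts, writes E into the workspace so the elementwise kernel
+            // that follows can transpose it back to the NGKHW/NGKDHW output layout.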
+ avg_time += RunGemm(arg, stream_config); + + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + const index_t grid_size = + arg.elementwise_block_2_ctile_map_transpose_e_.CalculateGridSize( + arg.e_in_transpose_desc_); + + const EDataType* p_e_in_grid = + type_convert(arg.p_workspace_) + + (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / + sizeof(EDataType); + + EDataType* p_e_out_grid = arg.p_e_grid_; + + auto kernel_transpose = kernel_elementwise, + ck::Tuple, + ck::Tuple, + ck::Tuple, + Block2TileMapElementwise, + element_wise::PassThrough>; + + avg_time += + launch_and_time_kernel(stream_config, + kernel_transpose, + dim3(grid_size), + dim3(ElementwiseBlocksize), + 0, + make_tuple(arg.e_in_transpose_desc_), + make_tuple(arg.e_out_transpose_desc_), + make_tuple(p_e_in_grid), + make_tuple(p_e_out_grid), + arg.elementwise_block_2_ctile_map_transpose_e_, + element_wise::PassThrough{}); + } + } + return avg_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + const index_t G = arg.b_g_k_c_xs_lengths_[I0]; + const index_t K = arg.b_g_k_c_xs_lengths_[I1]; + const index_t C = arg.b_g_k_c_xs_lengths_[I2]; + // Move this to runtime check to align Conv instances + // with Conv Multiple D instances + if constexpr(isMultiABD) + { + return false; + } + + // check device + if(get_device_name() == "gfx908") + { + // FIXME: re-enable fp64 when SWDEV-335738 is fixed + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + + if(!ck::is_xdl_supported()) + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t SpatialDim = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + { + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + + { + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + 
} + + if constexpr(is_NGCHW_GKCYX_NGKHW() || + is_NGCDHW_GKCZYX_NGKDHW()) + { + if((G * C) % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + + if((G * K) % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + + const index_t input_spatial_acum = ck::accumulate_n( + arg.a_g_n_c_wis_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + const index_t output_spatial_acum = ck::accumulate_n( + arg.e_g_n_k_wos_lengths_.begin() + I3, NDimSpatial, 1, std::multiplies<>()); + + if(input_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + + if(output_spatial_acum % CDEBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + + if(!arg.p_workspace_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Warning: Workspace for " + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3::Argument is not " + "allocated, use SetWorkSpacePointer." + << std::endl; + } + return false; + } + + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + if(!(arg.a_out_transpose_desc_.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + arg.e_in_transpose_desc_.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + { + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // Gridwise gemm v3 doesn't verify descriptors size + if(!arg.conv_to_gemm_transformer_.AreDescriptorsSmallerThan2GB()) + { + return false; + } + + // check Gridwise GEMM + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_.GetLength(I1); + const index_t GemmK = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + typename GridwiseGemm::Argument gemm_arg{ + nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, I1 /*KBatch*/}; + + return GridwiseGemm::CheckValidity(gemm_arg); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const void* p_as, + const void* p_bs, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto + MakeArgument(const void* p_as, + const void* p_bs, + const std::array& p_ds, + void* p_e, + 
const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return Argument{p_as, + p_bs, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + 
b_element_op, + cde_element_op); + } + + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << 
BBlockTransferSrcScalarPerVector << ", " + << CDEBlockTransferScalarPerVector_NPerBlock << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r.hpp new file mode 100755 index 000000000000..face627e1fb8 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r.hpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Grouped Convolution Forward: +// input : input image A[G, N, C, Hi, Wi], +// input : weight B[G, K, C, Y, X], +// input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ... +// output : output image E[G, N, K, Ho, Wo] +// output : R0[G, N, Ho, Wo], R1[G, N, Ho, Wo], ... +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Q0 = reduce0(q_op0(E)), Q1 = reduce1(q_op0(E)), ... +// R0 = r_op0(Q0), R1 = r_op1(Q1), ... +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct DeviceGroupedConvFwdMultipleDMultipleR : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t NumRTensor = RsDataType::Size(); + + virtual std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + std::array p_rs, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& r_g_n_wos_lengths, + const std::array& r_g_n_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const QsElementwiseOperation& qs_element_op, + const RsElementwiseOperation& rs_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp new file mode 100755 index 000000000000..ec1a05366e56 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -0,0 +1,1107 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" +#include "ck/library/utility/numeric.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +template +struct ComputePtrOffsetOfStridedBatch +{ + ComputePtrOffsetOfStridedBatch() = default; + + ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, + index_t BatchStrideB, + Array BatchStrideDs, + index_t BatchStrideE, + Array BatchStrideRs) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE), + BatchStrideRs_(BatchStrideRs) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideA_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideB_); + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + Array ds_offset; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { ds_offset(i) = g_idx * static_cast(BatchStrideDs_[i]); }); + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(BatchStrideE_); + } + + __host__ __device__ constexpr auto GetRsPtrOffset(index_t g_idx) const + { + Array rs_offset; + static_for<0, NumRTensor, 1>{}( + [&](auto i) { rs_offset(i) = g_idx * static_cast(BatchStrideRs_[i]); }); + return rs_offset; + } + + index_t BatchStrideA_; + index_t BatchStrideB_; + Array BatchStrideDs_; + index_t BatchStrideE_; + Array BatchStrideRs_; +}; + +/* + * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM. + * + * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix + * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly + * strided batched, but we can easily extend to other layouts. The returned offset can be either \p + * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB + * limitations. + * + * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and + * returns the 2D index of the tile that it computes. \see + * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run(). + * + * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2 + * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid + * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). 
\link + * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for + * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the + * computing of pointer offset into \p ComputePtrOffsetOfStridedBatch. + * + * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes. + * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to + * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion). + * + */ +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batch_gemm_multiple_d_xdl_cshuffle( + const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + RsPointer p_rs_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const QsElementwiseOperation qs_element_op, + const RsElementwiseOperation rs_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, + const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock, + const Block2ETileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + const auto rs_batch_offset = compute_ptr_offset_of_batch.GetRsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + RsPointer p_rs_grid_grp; + + static constexpr index_t NumRTensor = RsGridDescriptor_MBlock_MPerBlock::Size(); + + static_for<0, NumRTensor, 1>{}( + [&](auto i) { p_rs_grid_grp(i) = p_rs_grid[i] + rs_batch_offset[i]; }); + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_rs_grid_grp, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + rs_grid_desc_mblock_mperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + 
ignore = p_rs_grid; + ignore = batch_count; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; + ignore = rs_grid_desc_mblock_mperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = qs_element_op; + ignore = rs_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif +} + +} // namespace + +template +struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle + : public DeviceGroupedConvFwdMultipleDMultipleR +{ + using DeviceOp = DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t NumRTensor = RsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + return in_gemmm_gemmk_desc; + } + + template + static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + return wei_gemmn_gemmk_desc; + } + + template + static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + template + static auto GetPaddedRGridDescriptor(Descriptor descriptor, index_t MRaw) + { + const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; + const auto MPad = M - MRaw; + + if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M + return transform_tensor_descriptor( + descriptor, + make_tuple(make_right_pad_transform(descriptor, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + } + else + { + // not pad M + return descriptor; + } + } + + template || + is_same_v || + is_same_v, + bool>::type = false> + static auto + MakeRGridDescriptor_M(const std::array& r_g_n_wos_lengths, + const std::array& /* r_g_n_wos_strides */) + { + const index_t N = r_g_n_wos_lengths[1]; + + const index_t NHoWo = + N * ck::accumulate_n( + r_g_n_wos_lengths.begin() + 2, NDimSpatial, 1, std::multiplies<>()); + + const auto r_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(NHoWo)); + + return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo); + } + + template || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v, + bool>::type = false> + static auto MakeRGridDescriptor_M(const std::array& r_g_n_wos_lengths, + const std::array& 
r_g_n_wos_strides) + { + const index_t N = r_g_n_wos_lengths[1]; + + const index_t WoStride = r_g_n_wos_strides[NDimSpatial + 2]; + + const index_t NHoWo = + N * ck::accumulate_n( + r_g_n_wos_lengths.begin() + 2, NDimSpatial, 1, std::multiplies<>()); + + const auto r_grid_desc_mraw = + make_naive_tensor_descriptor(make_tuple(NHoWo), make_tuple(WoStride)); + + return GetPaddedRGridDescriptor(r_grid_desc_mraw, NHoWo); + } + + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; + using AGridDesc_M_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using BGridDesc_N_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using RGridDesc_M = remove_cvref_t({}, {}))>; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + ReduceAccDataType, + RsDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + QsElementwiseOperation, + RsElementwiseOperation, + ThreadReduceOperations, + InMemoryDataOperationEnum::Set, + RsGlobalMemoryDataOperation, + AGridDesc_M_K, + BGridDesc_N_K, + EGridDesc_M_N, + RGridDesc_M, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDRThreadTransferClusterLengths_MPerBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + RThreadTransferDstScalarPerVector_MPerBlock, + LoopSched>; + + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + + using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + std::array p_rs, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& r_g_n_wos_lengths, + const std::array& r_g_n_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const QsElementwiseOperation& qs_element_op, + const RsElementwiseOperation& rs_element_op) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + p_rs_grid_{}, // FIXME + conv_to_gemm_transformer_{a_g_n_c_wis_lengths, + 
a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, + a_grid_desc_m_k_{ + DeviceOp::MakeAGridDescriptor_M_K(conv_to_gemm_transformer_)}, + b_grid_desc_n_k_{ + DeviceOp::MakeBGridDescriptor_N_K(conv_to_gemm_transformer_)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, + r_grid_desc_m_{ + DeviceOp::MakeRGridDescriptor_M(r_g_n_wos_lengths, r_g_n_wos_strides)}, + a_grid_desc_ak0_m_ak1_{ + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{ + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + rs_grid_desc_mblock_mperblock_{}, + block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + compute_ptr_offset_of_batch_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + qs_element_op_{qs_element_op}, + rs_element_op_{rs_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // A/B/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + e_grid_desc_m_n_, + r_grid_desc_m_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); + + // D batch stride + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths[i], + ds_g_n_k_wos_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + // D desc + ds_grid_desc_m_n_(i) = + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_(i)); + }); + + // populate pointer for Rs + static_for<0, NumRTensor, 1>{}([&](auto i) { + using RDataType = remove_cvref_t>; + + // R pointer + p_rs_grid_(i) = static_cast(p_rs[i]); + + rs_grid_desc_mblock_mperblock_(i) = + GridwiseGemm::MakeRGridDescriptor_MBlock_MPerBlock(r_grid_desc_m_); + }); + } + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + 
static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + typename GridwiseGemm::RsGridPointer p_rs_grid_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + EGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + RGridDesc_M r_grid_desc_m_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor> + ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different + // type from E + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + + StaticallyIndexedArray + rs_grid_desc_mblock_mperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + QsElementwiseOperation qs_element_op_; + RsElementwiseOperation rs_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_k_wos_lengths_; + std::array, NumDTensor> ds_g_n_k_wos_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.r_grid_desc_m_, + arg.block_2_etile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * + arg.a_g_n_c_wis_lengths_[0]; // Group count + + const auto K = + arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_batch_gemm_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + typename GridwiseGemm::RsGridPointer, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + QsElementwiseOperation, + RsElementwiseOperation, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + ck::StaticallyIndexedArray< + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + NumDTensor>, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + ck::StaticallyIndexedArray< + typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock, + NumRTensor>, + Block2ETileMap, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.p_rs_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.qs_element_op_, + arg.rs_element_op_, + arg.a_g_n_c_wis_lengths_[0], // Group count + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.rs_grid_desc_mblock_mperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + // check device + if(get_device_name() == "gfx908") + { + if constexpr(!(is_same_v || is_same_v || + is_same_v)) + { + return false; + } + } + else if(ck::is_lds_direct_load_supported()) + { + if constexpr(!(is_same_v || is_same_v || + is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + 
return false; + } + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t C = arg.a_g_n_c_wis_lengths_[2]; + + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + const index_t C = arg.b_g_k_c_xs_lengths_[2]; + + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of Ds + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; + + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + valid = false; + } + } + else + { + valid = false; + } + }); + + if(!valid) + { + return false; + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.e_g_n_k_wos_lengths_[2]; + + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of R + if constexpr(!(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v)) + { + return false; + } + + // check Gridwise GEMM + return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_, + arg.b_grid_desc_n_k_, + arg.e_grid_desc_m_n_, + arg.r_grid_desc_m_, + arg.block_2_etile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + std::array p_rs, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& r_g_n_wos_lengths, + const std::array& r_g_n_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const QsElementwiseOperation& qs_element_op, + const RsElementwiseOperation& rs_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + p_rs, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + r_g_n_wos_lengths, + r_g_n_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op}; + } + 
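+    // Illustrative host-side flow (a sketch only; the pointer, lengths/strides and
+    // element-wise-op names below are placeholders chosen for the example, not part
+    // of this header):
+    //
+    //   auto argument = DeviceOp::MakeArgument(p_a, p_b, p_ds, p_e, p_rs,
+    //                                          a_lengths, a_strides,
+    //                                          b_lengths, b_strides,
+    //                                          ds_lengths, ds_strides,
+    //                                          e_lengths, e_strides,
+    //                                          r_lengths, r_strides,
+    //                                          conv_strides, conv_dilations,
+    //                                          left_pads, right_pads,
+    //                                          a_op, b_op, cde_op, qs_op, rs_op);
+    //   if(DeviceOp::IsSupportedArgument(argument))
+    //   {
+    //       auto invoker = DeviceOp::MakeInvoker();
+    //       invoker.Run(argument, StreamConfig{});
+    //   }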
+ static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + std::array p_rs, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& r_g_n_wos_lengths, + const std::array& r_g_n_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const QsElementwiseOperation& qs_element_op, + const RsElementwiseOperation& rs_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + p_rs, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + r_g_n_wos_lengths, + r_g_n_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op, + qs_element_op, + rs_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp new file mode 100755 index 000000000000..1c9f6b00948b --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,1010 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// +// @brief Device Convolution operation. +// +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// Assume: +// AK1 == BK1 +template +struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle + : public DeviceGroupedConvFwdMultipleABD +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD_Wmma_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + // K1 = Max Vector Access Pixels + static constexpr auto K1Number = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = 16; + + static constexpr auto AEnableLds_auto = NWaves == 1 ? false : true; + static constexpr auto BEnableLds_auto = MWaves == 1 ? 
false : true; + + // If true, LDS is used unconditionally + static constexpr auto AEnableLds_manu = true; + static constexpr auto BEnableLds_manu = true; + + static constexpr auto AEnableLds = + AEnableLds_auto || AEnableLds_manu || (NumGemmKPrefetchStage > 1); + static constexpr auto BEnableLds = + BEnableLds_auto || BEnableLds_manu || (NumGemmKPrefetchStage > 1); + + using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + static auto MakeAGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + const auto M = in_gemmm_gemmk_desc.GetLength(I0); + const auto K = in_gemmm_gemmk_desc.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(AEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto A_KRow = 2; + constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number; + const auto A_KWmma = K / WmmaK; + + const auto M0 = M / MPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple( + A_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(M0 * MRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + template + static auto MakeBGridDescriptor(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + const auto N = wei_gemmn_gemmk_desc.GetLength(I0); + const auto K = wei_gemmn_gemmk_desc.GetLength(I1); + assert(K % K1 == 0); + + if constexpr(BEnableLds) + { + const index_t K0 = K / K1; + + return transform_tensor_descriptor( + wei_gemmn_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + constexpr auto B_KRow = 2; + constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number; + const auto B_KWmma = K / WmmaK; + + const auto N0 = N / NPerBlock; + // 0 1 0 1 2 3 4 5 6 + // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1 + return transform_tensor_descriptor( + wei_gemmn_gemmk_desc, + make_tuple(make_unmerge_transform(make_tuple( + B_KWmma, Number{}, Number{}, K1Number)), + make_unmerge_transform( + make_tuple(N0 * NRepeat, Number{}, Number{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + } + + template + static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); + + const auto out_gemmm_gemmn_desc = + 
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer); + }, + Number{}); + } + + // desc for problem definition + constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer; + using AGridDesc = + decltype(DeviceOp::MakeAGridDescriptor(dummy_conv_to_gemm_transformer)); + using BGridDesc = + decltype(DeviceOp::MakeBGridDescriptor(dummy_conv_to_gemm_transformer)); + using DsGridDesc_M_N = + remove_cvref_t; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + + // GridwiseOp + using GridwiseOp = GridwiseGemmMultipleD_Wmma< + // DataType Family + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + // InMemory Data Descriptor + AGridDesc, + BGridDesc, + DsGridDesc_M_N, + EGridDesc_M_N, + // ElementwiseOp Family + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + // Tiling Family + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + K1, + MRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + AEnableLds, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BEnableLds, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + NumGemmKPrefetchStage, + LoopSched, + PipelineVer>; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + index_t M01, + index_t N01, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_c_wis_lengths[0]}, + conv_to_gemm_transformer_{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_)}, + a_grid_desc_{DeviceOp::MakeAGridDescriptor(conv_to_gemm_transformer_)}, + 
b_grid_desc_{DeviceOp::MakeBGridDescriptor(conv_to_gemm_transformer_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01)}, + compute_ptr_offset_of_batch_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths}, + ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + // A/B/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + // using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds[i]); + + // D batch stride + compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0]; + }); + + // D desc + ds_grid_desc_m_n_ = generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths[i], + ds_g_n_k_wos_strides[i], + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + return DeviceOp::MakeEGridDescriptor_M_N(conv_to_gemm_transformer_d); + }, + Number{}); + + // populate desc for Ds/E + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseOp::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + index_t num_group_; + + ConvToGemmFwdTransformer conv_to_gemm_transformer_; + + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc a_grid_desc_; + BGridDesc b_grid_desc_; + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + typename GridwiseOp::DefaultBlock2CTileMap block_2_etile_map_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + 
AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array, NumDTensor> ds_g_n_k_wos_lengths_; + std::array, NumDTensor> ds_g_n_k_wos_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_; + + const auto K = [&]() { + if constexpr(AEnableLds) + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2); + } + else + { + return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) * + arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6); + } + }(); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle< + GridwiseOp, + ADataType, + BDataType, + typename GridwiseOp::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc, + DeviceOp::BGridDesc, + typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseOp::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + remove_reference_t, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_g_n_c_wis_lengths_[0], // Group count + arg.a_grid_desc_, + arg.b_grid_desc_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_etile_map_, + arg.compute_ptr_offset_of_batch_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + // check device + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + else + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return 
false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t C = arg.a_g_n_c_wis_lengths_[2]; + + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + + { + const index_t C = arg.b_g_k_c_xs_lengths_[2]; + + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of Ds + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DLayout = remove_cvref_t>; + + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v) + { + const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; + + if(!(K % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + valid = false; + } + } + else + { + valid = false; + } + }); + + if(!valid) + { + return false; + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + const index_t K = arg.e_g_n_k_wos_lengths_[2]; + + if(!(K % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + // check Gridwise GEMM + return GridwiseOp::CheckValidity(arg.a_grid_desc_, + arg.b_grid_desc_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op}; 
+ } + + static auto + MakeArgument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + 
e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + std::array a_g_n_c_wis_lengths_i32; + std::array a_g_n_c_wis_strides_i32; + std::array b_g_k_c_xs_lengths_i32; + std::array b_g_k_c_xs_strides_i32; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i32; + std::array, NumDTensor> ds_g_n_k_wos_strides_i32; + std::array e_g_n_k_wos_lengths_i32; + std::array e_g_n_k_wos_strides_i32; + std::array conv_filter_strides_i32; + std::array conv_filter_dilations_i32; + std::array input_left_pads_i32; + std::array input_right_pads_i32; + + array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i32, conv_filter_strides); + array_convert(conv_filter_dilations_i32, conv_filter_dilations); + array_convert(input_left_pads_i32, input_left_pads); + array_convert(input_right_pads_i32, input_right_pads); + + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i32, + a_g_n_c_wis_strides_i32, + b_g_k_c_xs_lengths_i32, + b_g_k_c_xs_strides_i32, + ds_g_n_k_wos_lengths_i32, + ds_g_n_k_wos_strides_i32, + e_g_n_k_wos_lengths_i32, + e_g_n_k_wos_strides_i32, + conv_filter_strides_i32, + conv_filter_dilations_i32, + input_left_pads_i32, + input_right_pads_i32, + 1, + 1, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvFwdMultipleD_Wmma_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << K1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "BEnableLds: " + << BEnableLds << ", " + << "ABlockTransferSrcScalarPerVector: " + << ABlockTransferSrcScalarPerVector << ", " + << "BBlockTransferSrcScalarPerVector: " + << BBlockTransferSrcScalarPerVector; + // clang-format on + + return str.str(); + } +}; + +} // namespace 
device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp new file mode 100755 index 000000000000..09ff8207574b --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,138 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// +// @brief Device Convolution operation. +// @note This structure is deprecated (left for backwards compatibility). Please use +// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle. +// Supports: +// @li Forward convolution with up to 3 spatial dimentions +// @li Input tensor in GNWC data format +// @li Weight tensor in GKXC data format +// @li Output tensor in GNWK data format +// +// 1D: +// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C] +// 2D: +// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C] +// 3D: +// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C] +// +template ::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + typename BComputeDataType = AComputeDataType, + LoopScheduler LoopSched = make_default_loop_scheduler()> +using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ConvForwardSpecialization, + GemmSpec, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + AComputeDataType, + BComputeDataType, + LoopSched>; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp new file mode 100755 index 000000000000..998836795982 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -0,0 +1,1073 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle( + Array gemm_desc_kernel_args, + const index_t gemms_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op, + const ComputePtrOffset compute_ptr_offset_of_groups, + const ComputePtrOffset compute_ptr_offset_of_n) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); + + const long_index_t a_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx)); + const long_index_t b_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx)); + const long_index_t e_group_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx)); + + const long_index_t a_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); + const long_index_t e_n_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); + + index_t left = 0; + index_t right = gemms_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id_x >= gemm_desc_kernel_args[group_id].BlockStart_ && + block_id_x < gemm_desc_kernel_args[group_id].BlockEnd_)) && + left <= right) + { + if(block_id_x < gemm_desc_kernel_args[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::template Run( + gemm_desc_kernel_args[group_id].a_ptr_ + a_group_offset + a_n_offset, + gemm_desc_kernel_args[group_id].b_ptr_ + b_group_offset, + Tuple<>{}, + 
gemm_desc_kernel_args[group_id].e_ptr_ + e_group_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + c_element_op, + gemm_desc_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_desc_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + Tuple<>{}, + gemm_desc_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_desc_kernel_args[group_id].block_2_etile_map_); +#else + ignore = gemm_desc_kernel_args; + ignore = gemms_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = compute_ptr_offset_of_groups; + ignore = compute_ptr_offset_of_n; +#endif +} + +} // namespace + +template +using is_tuple = decltype(std::declval().IsTuple()); + +template ::value, + Number<0>, + ADataType>()), // ComputeType is InputType by default (first + // in tuple for MultiAB), unpack if tuple was + // passed + typename BComputeDataType = AComputeDataType, + LoopScheduler LoopSched = make_default_loop_scheduler()> +struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor + : public DeviceGroupedConvFwdMultipleABD +{ + using DeviceOp = DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor; + + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t MaxGemmsNum = 32; + static constexpr bool DoElementwiseBeforeCShuffle = + NumDTensor == 0 && is_same_v && + !is_same_v; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ConvToGemmFwdTransformerIndexT = TransformConvFwdToGemm; + + using ConvToGemmFwdTransformerLongIndexT = TransformConvFwdToGemm; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + template + static auto + MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer) + { + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + + return in_gemmm_gemmk_desc; + } + + template + static auto + MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer) + { + const auto wei_gemmnraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeBDescriptor_N_K(); + + const auto wei_gemmn_gemmk_desc = + matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc); + + return wei_gemmn_gemmk_desc; + } + + template + static auto + MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformerIndexT& conv_to_gemm_transformer) + { + const auto out_gemmmraw_gemmnraw_desc = + conv_to_gemm_transformer.template MakeCDescriptor_M_N(); + + const auto out_gemmm_gemmn_desc = + matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc); + + return out_gemmm_gemmn_desc; + } + + // desc for problem definition + constexpr static ConvToGemmFwdTransformerIndexT dummy_conv_to_gemm_transformer; + using AGridDesc_M_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using BGridDesc_N_K = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + using EGridDesc_M_N = + remove_cvref_t(dummy_conv_to_gemm_transformer))>; + + static auto + GenerateConvToGemmTransforms(ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformer_base, + const ADataType* a_grid_ptr_base, + EDataType* c_grid_ptr_base) + { + // Max number of splits + // We need to use it to avoid infinity loop + constexpr index_t max_split_numbers = MaxGemmsNum / 2; + // Arrays to store transformers with smaller descs 
than 2GB + Array conv_to_gemm_transformers_arr; + Array a_grid_ptrs_arr; + Array c_grid_ptrs_arr; + // Queue for spliting + std::queue conv_to_gemm_transformers_queue( + {conv_to_gemm_transformer_base}); + std::queue a_grid_ptrs_queue({a_grid_ptr_base}); + std::queue c_grid_ptrs_queue({c_grid_ptr_base}); + + index_t gemms_number = 0; + index_t split_numbers = 0; + // Algorithm: + // While queue is not empty: + // 1. Get transformer from queue. + // 2. If descs are smaller than 2GB push to result array. + // 3. If descs are bigger than 2GB split into left and right transformer. + // and push the both into the queue. + while(!conv_to_gemm_transformers_queue.empty() && split_numbers < max_split_numbers && + gemms_number < MaxGemmsNum) + { + // Get transformer from the queue + const auto& conv_to_gemm_transformer = conv_to_gemm_transformers_queue.front(); + const ADataType* a_grid_ptr = a_grid_ptrs_queue.front(); + EDataType* c_grid_ptr = c_grid_ptrs_queue.front(); + + // Check if convolution not exceed 2GB + if(conv_to_gemm_transformer.AreDescriptorsSmallerThan2GB()) + { + // If yes, push into result array + conv_to_gemm_transformers_arr(gemms_number) = + ConvToGemmFwdTransformerIndexT{conv_to_gemm_transformer}; + a_grid_ptrs_arr(gemms_number) = a_grid_ptr; + c_grid_ptrs_arr(gemms_number) = c_grid_ptr; + gemms_number++; + } + else + { + // If no, split into left and right convolutions + ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformers_left_part, + conv_to_gemm_transformers_right_part; + const ADataType* a_grid_right_ptr; + EDataType* c_grid_right_ptr; + + ck::tie(conv_to_gemm_transformers_left_part, + conv_to_gemm_transformers_right_part, + a_grid_right_ptr, + c_grid_right_ptr) = + conv_to_gemm_transformer.SplitConvProblem(a_grid_ptr, c_grid_ptr); + + conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_left_part); + conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_right_part); + // Left offsets remain the same + a_grid_ptrs_queue.push(a_grid_ptr); + a_grid_ptrs_queue.push(a_grid_right_ptr); + c_grid_ptrs_queue.push(c_grid_ptr); + c_grid_ptrs_queue.push(c_grid_right_ptr); + split_numbers++; + } + // Remove from the queue + conv_to_gemm_transformers_queue.pop(); + a_grid_ptrs_queue.pop(); + c_grid_ptrs_queue.pop(); + } + + const bool is_split_valid = conv_to_gemm_transformers_queue.empty(); + + return ck::make_tuple(conv_to_gemm_transformers_arr, + a_grid_ptrs_arr, + c_grid_ptrs_arr, + gemms_number, + is_split_valid); + } + +#define GridwiseGemmTemplateParameters \ + ADataType, BDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ + AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ + NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, 
PipelineVersion::v1, \ + AComputeDataType, DoElementwiseBeforeCShuffle + // Use appropriate gridwise gemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle; + + // desc for blockwise copy + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // block-to-e-tile map + using Block2ETileMap = + remove_cvref_t; + // Structure for each gemm(conv) + struct GemmArgs + { + // pointers + const ADataType* a_ptr_; + const BDataType* b_ptr_; + EDataType* e_ptr_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + ck::index_t BlockStart_, BlockEnd_; + }; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, + const void* p_b, + const std::array& /*p_ds*/, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + /*ds_g_n_k_wos_lengths*/, + const std::array, NumDTensor>& + /*ds_g_n_k_wos_strides*/, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + : num_group_{static_cast(a_g_n_c_wis_lengths[0])}, + compute_ptr_offset_of_groups_{}, + compute_ptr_offset_of_n_{}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths}, + a_g_n_c_wis_strides_{a_g_n_c_wis_strides}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + b_g_k_c_xs_strides_{b_g_k_c_xs_strides}, + e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths}, + e_g_n_k_wos_strides_{e_g_n_k_wos_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + if constexpr(NumDTensor == 0) + { + // Perform grouped gemm, generate array of tranformer for convolution + Array conv_to_gemm_transformer_arr; + Array a_grid_ptrs; + Array c_grid_ptrs; + + ck::tie(conv_to_gemm_transformer_arr, + a_grid_ptrs, + c_grid_ptrs, + gemms_count_, + is_split_valid_) = + GenerateConvToGemmTransforms( + ConvToGemmFwdTransformerLongIndexT{a_g_n_c_wis_lengths_, + a_g_n_c_wis_strides_, + b_g_k_c_xs_lengths_, + b_g_k_c_xs_strides_, + e_g_n_k_wos_lengths_, + e_g_n_k_wos_strides_, + conv_filter_strides_, + conv_filter_dilations_, + input_left_pads_, + input_right_pads_}, + static_cast(p_a), + static_cast(p_e)); + + grid_size_ = 0; + valid_gemms_count_ = 0; + + if(is_split_valid_) + { + // Create GemmArg for each gemm(conv) + for(index_t i = 0; i < gemms_count_; i++) + { + const AGridDesc_M_K a_grid_desc_m_k{ + DeviceOp::MakeAGridDescriptor_M_K( + conv_to_gemm_transformer_arr[i])}; + const BGridDesc_N_K b_grid_desc_n_k{ + DeviceOp::MakeBGridDescriptor_N_K( + conv_to_gemm_transformer_arr[i])}; + const auto e_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N( + conv_to_gemm_transformer_arr[i]); + + const auto 
block_2_etile_map = + GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + + const index_t grid_size_grp = + block_2_etile_map.CalculateGridSize(e_grid_desc_m_n); + + const index_t BlockStart = grid_size_; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + Tuple<>{}, + e_grid_desc_m_n, + block_2_etile_map)) + { + + gemm_desc_kernel_args_(valid_gemms_count_) = GemmArgs{ + a_grid_ptrs[i], + static_cast(p_b), + c_grid_ptrs[i], + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k), + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k), + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n), + block_2_etile_map, + BlockStart, + BlockEnd}; + + valid_gemms_count_++; + } + } + // N is the same for all convs + conv_N_per_block_ = static_cast(conv_to_gemm_transformer_arr[I0].N_); + } + + // Strides for G and N remain the same + compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0]; + + compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_; + compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_; + } + } + + void Print() const + { + for(index_t i = 0; i < valid_gemms_count_; i++) + { + std::cout << "A[AK0, M, AK1]: " << gemm_desc_kernel_args_[i].a_grid_desc_ak0_m_ak1_ + << std::endl; + std::cout << "B[BK0, N, BK1]: " << gemm_desc_kernel_args_[i].b_grid_desc_bk0_n_bk1_ + << std::endl; + std::cout + << "E[MBlock, MPerBlock, NBlock, NPerBlock]: " + << gemm_desc_kernel_args_[i].e_grid_desc_mblock_mperblock_nblock_nperblock_ + << std::endl; + } + } + + index_t num_group_; + index_t conv_N_per_block_; + + Array gemm_desc_kernel_args_; + + index_t grid_size_; + index_t gemms_count_; + index_t valid_gemms_count_; + + bool is_split_valid_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_groups_; + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_n_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // for checking IsSupportedArgument() + std::array a_g_n_c_wis_lengths_; + std::array a_g_n_c_wis_strides_; + std::array b_g_k_c_xs_lengths_; + std::array b_g_k_c_xs_strides_; + std::array e_g_n_k_wos_lengths_; + std::array e_g_n_k_wos_strides_; + std::array conv_filter_strides_; + std::array conv_filter_dilations_; + std::array input_left_pads_; + std::array input_right_pads_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const DeviceOp::Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if constexpr(NumDTensor == 0) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + const index_t num_workgroups_per_Conv_N = + arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_; + + const index_t gdx = arg.grid_size_; + const index_t gdy = arg.num_group_; + const index_t gdz = num_workgroups_per_Conv_N; + + // K is constant for all gemms + const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) * + arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + const auto kernel 
= + kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle< + GridwiseGemm, + MaxGemmsNum, + GemmArgs, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + ComputePtrOffsetOfStridedBatch, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg.gemm_desc_kernel_args_, + arg.gemms_count_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.compute_ptr_offset_of_groups_, + arg.compute_ptr_offset_of_n_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + else + { + return 0.f; + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + namespace ctc = tensor_layout::convolution; + + const long_index_t K = arg.b_g_k_c_xs_lengths_[I1]; + const long_index_t C = arg.b_g_k_c_xs_lengths_[I2]; + // Move this to runtime check to align Conv instances + // with Conv Multiple D instances + if constexpr(NumDTensor != 0) + { + return false; + } + + // Check if all descs are valid + if(!(arg.is_split_valid_ && arg.gemms_count_ == arg.valid_gemms_count_)) + { + return false; + } + // check device + if(get_device_name() == "gfx908") + { + // FIXME: re-enable fp64 when SWDEV-335738 is fixed + if constexpr(!(is_same_v || is_same_v)) + { + return false; + } + } + if(!ck::is_xdl_supported()) + { + return false; + } + + // check ConvolutionForwardSpecialization + if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t ConvStride = arg.conv_filter_strides_[i]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == + ConvolutionForwardSpecialization::Filter1x1Pad0) + { + // check if it's 1x1 conv + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t X = arg.b_g_k_c_xs_lengths_[i + 3]; + const index_t LeftPad = arg.input_left_pads_[i]; + const index_t RightPad = arg.input_right_pads_[i]; + + if(!(X == 1 && LeftPad == 0 && RightPad == 0)) + { + return false; + } + } + } + else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3) + { + if(C != 1) + { + return false; + } + for(index_t i = 0; i < NDimSpatial; ++i) + { + const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3]; + + if(filter_spatial_dim != I3) + { + return false; + } + } + if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK()) + { + return false; + } + } + + // check vector access of A + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + // Check access per C + if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of B + // FIXME: layout + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + 
is_same_v) + + { + if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0)) + { + return false; + } + } + else + { + return false; + } + + // check vector access of E + if constexpr(is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v) + { + if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + } + else + { + return false; + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + std::array a_g_n_c_wis_lengths_i64; + std::array a_g_n_c_wis_strides_i64; + std::array b_g_k_c_xs_lengths_i64; + std::array b_g_k_c_xs_strides_i64; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i64; + std::array, NumDTensor> ds_g_n_k_wos_strides_i64; + std::array e_g_n_k_wos_lengths_i64; + std::array e_g_n_k_wos_strides_i64; + std::array conv_filter_strides_i64; + std::array conv_filter_dilations_i64; + std::array input_left_pads_i64; + std::array input_right_pads_i64; + + array_convert(a_g_n_c_wis_lengths_i64, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i64, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i64, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i64, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i64[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i64[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i64, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i64, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i64, conv_filter_strides); + array_convert(conv_filter_dilations_i64, conv_filter_dilations); + array_convert(input_left_pads_i64, input_left_pads); + array_convert(input_right_pads_i64, input_right_pads); + + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i64, + a_g_n_c_wis_strides_i64, + b_g_k_c_xs_lengths_i64, + b_g_k_c_xs_strides_i64, + ds_g_n_k_wos_lengths_i64, + ds_g_n_k_wos_strides_i64, + e_g_n_k_wos_lengths_i64, + e_g_n_k_wos_strides_i64, + conv_filter_strides_i64, + conv_filter_dilations_i64, + input_left_pads_i64, + input_right_pads_i64, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto + MakeArgument(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const 
std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& ds_g_n_k_wos_lengths, + const std::array, NumDTensor>& ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + + std::array a_g_n_c_wis_lengths_i64; + std::array a_g_n_c_wis_strides_i64; + std::array b_g_k_c_xs_lengths_i64; + std::array b_g_k_c_xs_strides_i64; + std::array, NumDTensor> ds_g_n_k_wos_lengths_i64; + std::array, NumDTensor> ds_g_n_k_wos_strides_i64; + std::array e_g_n_k_wos_lengths_i64; + std::array e_g_n_k_wos_strides_i64; + std::array conv_filter_strides_i64; + std::array conv_filter_dilations_i64; + std::array input_left_pads_i64; + std::array input_right_pads_i64; + + array_convert(a_g_n_c_wis_lengths_i64, a_g_n_c_wis_lengths); + array_convert(a_g_n_c_wis_strides_i64, a_g_n_c_wis_strides); + array_convert(b_g_k_c_xs_lengths_i64, b_g_k_c_xs_lengths); + array_convert(b_g_k_c_xs_strides_i64, b_g_k_c_xs_strides); + for(index_t d = 0; d < NumDTensor; d++) + { + array_convert(ds_g_n_k_wos_lengths_i64[d], ds_g_n_k_wos_lengths[d]); + array_convert(ds_g_n_k_wos_strides_i64[d], ds_g_n_k_wos_strides[d]); + } + array_convert(e_g_n_k_wos_lengths_i64, e_g_n_k_wos_lengths); + array_convert(e_g_n_k_wos_strides_i64, e_g_n_k_wos_strides); + array_convert(conv_filter_strides_i64, conv_filter_strides); + array_convert(conv_filter_dilations_i64, conv_filter_dilations); + array_convert(input_left_pads_i64, input_left_pads); + array_convert(input_right_pads_i64, input_right_pads); + + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths_i64, + a_g_n_c_wis_strides_i64, + b_g_k_c_xs_lengths_i64, + b_g_k_c_xs_strides_i64, + ds_g_n_k_wos_lengths_i64, + ds_g_n_k_wos_strides_i64, + e_g_n_k_wos_lengths_i64, + e_g_n_k_wos_strides_i64, + conv_filter_strides_i64, + conv_filter_dilations_i64, + input_left_pads_i64, + input_right_pads_i64, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + const std::array& p_ds, + void* p_e, + const std::array& a_g_n_c_wis_lengths, + const std::array& a_g_n_c_wis_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>& + ds_g_n_k_wos_lengths, 
+ const std::array, NumDTensor>& + ds_g_n_k_wos_strides, + const std::array& e_g_n_k_wos_lengths, + const std::array& e_g_n_k_wos_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op) override + { + + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << getConvForwardSpecializationString(ConvForwardSpecialization) << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CDEBlockTransferScalarPerVector_NPerBlock << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp new file mode 100755 index 000000000000..5de429f9e530 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp @@ -0,0 +1,319 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. 
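//
// The helpers below are compile-time layout predicates (e.g.
// is_NHWGC_GKYXC_NHWGK) that device ops use to constrain supported tensor
// layouts, plus ComputePtrOffsetOfStridedBatch, which turns a group index into
// 64-bit pointer offsets (g_idx * BatchStride) so each group's slice of a
// large tensor can be addressed without 32-bit overflow. A minimal usage
// sketch, assuming the single-A/B specialization; the helper name group_a_ptr
// is hypothetical and kept out of the build:
#if 0
template <typename ComputePtrOffset>
__device__ const float* group_a_ptr(const float* p_a_base,
                                    ck::index_t g_idx,
                                    const ComputePtrOffset& ptr_offset)
{
    // GetAPtrOffset returns long_index_t(g_idx) * BatchStrideA_.
    return p_a_base + ptr_offset.GetAPtrOffset(g_idx);
}
#endif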
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// 1d +template +constexpr bool is_NWGC_GKXC_NWGK() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_GNWC_GKXC_GNWK() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_NGCW_GKXC_NGKW() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +// 2d +template +constexpr bool is_NHWGC_GKYXC_NHWGK() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_GNHWC_GKYXC_GNHWK() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_NGCHW_GKYXC_NGKHW() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_NGCHW_GKCYX_NGKHW() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_NGCHW_NGKHW() +{ + return is_same_v && + is_same_v; +} + +// 3d +template +constexpr bool is_NDHWGC_GKZYXC_NDHWGK() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_GNDHWC_GKZYXC_GNDHWK() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_NGCDHW_GKZYXC_NGKDHW() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_NGCDHW_GKCZYX_NGKDHW() +{ + return is_same_v && + is_same_v && + is_same_v; +} + +template +constexpr bool is_NGCDHW_NGKDHW() +{ + return is_same_v && + is_same_v; +} + +template +constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK() +{ + return is_NWGC_GKXC_NWGK() || + is_NHWGC_GKYXC_NHWGK() || + is_NDHWGC_GKZYXC_NDHWGK(); +} + +template +constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK() +{ + return is_GNWC_GKXC_GNWK() || + is_GNHWC_GKYXC_GNHWK() || + is_GNDHWC_GKZYXC_GNDHWK(); +} + +template +constexpr bool is_NGCSpatial_GKSpatial_NGKSpatial() +{ + return is_NGCW_GKXC_NGKW() || + is_NGCHW_GKYXC_NGKHW() || + is_NGCDHW_GKZYXC_NGKDHW(); +} + +template +struct ComputePtrOffsetOfStridedBatch +{ +}; + +template +struct ComputePtrOffsetOfStridedBatch 1 || NumBTensor > 1)>> +{ + ComputePtrOffsetOfStridedBatch() = default; + + ComputePtrOffsetOfStridedBatch(Array& BatchStrideAs, + Array& BatchStrideBs, + Array& BatchStrideDs, + long_index_t BatchStrideE) + : BatchStrideA_(BatchStrideAs), + BatchStrideB_(BatchStrideBs), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr auto GetAsPtrOffset(index_t g_idx) const + { + Array as_offset; + static_for<0, NumATensor, 1>{}( + [&](auto i) { as_offset(i) = static_cast(g_idx) * BatchStrideA_[i]; }); + return as_offset; + } + + __host__ __device__ constexpr auto GetBsPtrOffset(index_t g_idx) const + { + Array bs_offset; + static_for<0, NumBTensor, 1>{}( + [&](auto i) { bs_offset(i) = static_cast(g_idx) * BatchStrideB_[i]; }); + return bs_offset; + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + Array ds_offset; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { ds_offset(i) = static_cast(g_idx) * BatchStrideDs_[i]; }); + return ds_offset; + } + + [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * BatchStrideE_; + } + + // alias for kernels without multiple D + [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * BatchStrideE_; + } + + Array 
BatchStrideA_; + Array BatchStrideB_; + Array BatchStrideDs_; + long_index_t BatchStrideE_; + long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D +}; + +template +struct ComputePtrOffsetOfStridedBatch> +{ + ComputePtrOffsetOfStridedBatch() = default; + + ComputePtrOffsetOfStridedBatch(long_index_t BatchStrideA, + long_index_t BatchStrideB, + Array BatchStrideDs, + long_index_t BatchStrideE) + : BatchStrideA_(BatchStrideA), + BatchStrideB_(BatchStrideB), + BatchStrideDs_(BatchStrideDs), + BatchStrideE_(BatchStrideE) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * BatchStrideA_; + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * BatchStrideB_; + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + Array ds_offset; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { ds_offset(i) = static_cast(g_idx) * BatchStrideDs_[i]; }); + return ds_offset; + } + + [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * BatchStrideE_; + } + + // alias for kernels without multiple D + [[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * BatchStrideE_; + } + + long_index_t BatchStrideA_; + long_index_t BatchStrideB_; + Array BatchStrideDs_; + long_index_t BatchStrideE_; + long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D +}; + +template +constexpr static auto GetNumABTensors() +{ + if constexpr(isTuple) + { + return Number{}; + } + else + { + return Number<1>{}; + } +} + +template +constexpr static auto GetAGridPointer() +{ + if constexpr(isTuple) + { + return typename GridwiseGemm::AsGridPointer{}; + } + else + { + return Tuple{}; + } +} + +template +constexpr static auto GetBGridPointer() +{ + if constexpr(isTuple) + { + return typename GridwiseGemm::BsGridPointer{}; + } + else + { + return Tuple{}; + } +} + +template +constexpr static auto UnpackDataType() +{ + if constexpr(isTuple) + { + // unpack if tuple + return tuple_element_t{}; + } + else + { + // if no, return Type + return Type{}; + } +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp new file mode 100755 index 000000000000..21afc0604068 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp @@ -0,0 +1,850 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
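//
// The grouped GEMM kernel in this header gives every group a contiguous slice
// of grid_size_grp workgroups, so a block recovers its owning group with one
// integer division and, if the group needs more tiles than grid_size_grp,
// loops over them with a stride of grid_size_grp. A small sketch of the
// mapping, assuming the hypothetical helper name owning_group:
#if 0
ck::index_t owning_group(ck::index_t block_id, ck::index_t grid_size_grp)
{
    // Mirrors group_id = block_id / grid_size_grp in
    // kernel_grouped_gemm_xdl_fixed_nk; callers still check
    // group_id < group_count before reading the descriptor array.
    return block_id / grid_size_grp;
}
#endif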
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const index_t grid_size_grp, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t KBatch = 1; + + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = block_id / grid_size_grp; + + if(group_id >= group_count) + return; + + const index_t M = gemm_desc_ptr[group_id].M; + const index_t N = gemm_desc_ptr[group_id].N; + const index_t K = gemm_desc_ptr[group_id].K; + + if(M == 0 || N == 0 || K == 0) + return; + + const auto StrideAs = gemm_desc_ptr[group_id].StrideAs; + const auto StrideBs = gemm_desc_ptr[group_id].StrideBs; + const auto StrideDs = gemm_desc_ptr[group_id].StrideDs; + const auto StrideE = gemm_desc_ptr[group_id].StrideE; + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); + + const index_t BlockStart = group_id * grid_size_grp; + + const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; + + const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); + + constexpr auto NumATensor = GridwiseGemm::AsGridPointer::Size(); + constexpr auto NumBTensor = GridwiseGemm::BsGridPointer::Size(); + constexpr auto NumDTensor = GridwiseGemm::DsGridPointer::Size(); + + typename GridwiseGemm::AsGridPointer p_as_grid_; + typename GridwiseGemm::BsGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t; + p_as_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_as_grid[i]); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t; + p_bs_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_bs_grid[i]); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t; + p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + index_t id_off = 0; + index_t id_local = get_block_1d_id() - BlockStart; + + while(id_local < local_grid_size) + { + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); + + GridwiseGemm:: + template Run( + p_as_grid_, + p_bs_grid_, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + M, + N, + K, + StrideAs, + StrideBs, + StrideDs, + StrideE, + block_2_etile_map); + + 
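        // Every group is launched with the same grid_size_grp blocks; when this
        // group's local grid is larger, advance by grid_size_grp and remap the
        // next tile through a fresh GroupedGemmBlock2ETileMap until id_local
        // covers local_grid_size.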
id_off += grid_size_grp; + id_local += grid_size_grp; + } +#else + ignore = gemm_descs_const; + ignore = group_count; + ignore = grid_size_grp; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; +#endif +} + +template +struct DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK + : public DeviceGroupedGemmMultiABDFixedNK +{ + using DeviceOp = DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK; + + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr index_t NumGemmKPrefetchStage = 1; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleABD_xdl_cshuffle< + AsDataType, + BsDataType, + ComputeType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + template + struct OffsettedBlockToCTileMapMLoops + { + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMapMLoops( + UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) + { + block_to_ctile_map_ = block_to_ctile_map; + block_start_ = block_start; + id_off_ = id_off; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); + + return make_tuple( + // idx_bot[Number<0>{}], + idx_bot[Number<1>{}], + idx_bot[Number<2>{}]); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + template + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); + } + + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t block_start_; + index_t id_off_; + }; + + template + struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops + { + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const 
BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, + index_t N, + index_t KBatch, + index_t M01 = 8) + : M_(M), N_(N), KBatch_(KBatch), M01_(M01) + { + } + + template + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) + : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) + { + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0 * KBatch_; + } + + template + __host__ __device__ constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); + + block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? 
M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t KBatch_; + index_t M01_; + }; + + using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + + struct GemmBiasTransKernelArg + { + // pointers + std::array as_ptr_; + std::array bs_ptr_; + std::array ds_ptr_; + void* e_ptr_; + + index_t M_, N_, K_; + std::array StrideAs_; + std::array StrideBs_; + std::array StrideDs_; + index_t StrideE_; + }; + + // Argument + struct Argument : public BaseArgument + { + + void UpdateKBatch(index_t) {} + + Argument(std::vector>&, + std::vector>&, + std::vector>&, + std::vector&, + std::vector& gemm_descs, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) + : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} + { + grid_size_ = 0; + + k_batch_ = 1; + + grouped_gemm_kernel_args_dev = nullptr; + + group_count_ = ck::type_convert(gemm_descs.size()); + + gemm_desc_kernel_arg_.reserve(group_count_); + + index_t group_id = 0; + + sum_of_m = gemm_descs[0].M_; + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + const index_t N = gemm_descs[0].N_; + const index_t K = gemm_descs[0].K_; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_) + { + throw std::runtime_error("wrong! M/N/K is not identical"); + } + + a_mtx_mraw_kraw_.emplace_back(sum_of_m, K); + b_mtx_nraw_kraw_.emplace_back(N, K); + + // pointer + std::array p_as_grid; + std::array p_bs_grid; + std::array p_ds_grid; + + static_for<0, NumATensor, 1>{}([&](auto j) { p_as_grid[j] = nullptr; }); + static_for<0, NumBTensor, 1>{}([&](auto j) { p_bs_grid[j] = nullptr; }); + static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; }); + + std::array StrideAs; + std::array StrideBs; + std::array StrideDs; + + const index_t StrideE = gemm_descs[i].stride_C_; + + if(gemm_descs[i].stride_As_.size() != NumATensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_As_.size() does not match NumATensor"); + } + + static_for<0, NumATensor, 1>{}( + [&](auto j) { StrideAs[j] = gemm_descs[i].stride_As_[j]; }); + + if(gemm_descs[i].stride_Bs_.size() != NumBTensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_Bs_.size() does not match NumBTensor"); + } + + static_for<0, NumBTensor, 1>{}( + [&](auto j) { StrideBs[j] = gemm_descs[i].stride_Bs_[j]; }); + + if(gemm_descs[i].stride_Ds_.size() != NumDTensor) + { + throw std::runtime_error( + "wrong! 
gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + static_for<0, NumDTensor, 1>{}( + [&](auto j) { StrideDs[j] = gemm_descs[i].stride_Ds_[j]; }); + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + AverM, N, StrideE); + + // block-to-e-tile map + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + if(group_id * grid_size_grp_ != grid_size_) + { + throw std::runtime_error("wrong! grid_size_grp_ is not identical!"); + } + + grid_size_ += grid_size_grp_; + + // check block-to-E-tile + if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) + { + throw std::runtime_error("wrong! block_2_etile_map validation failed"); + } + + gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{ + p_as_grid, + p_bs_grid, + p_ds_grid, + nullptr, + AverM, + N, + K, + StrideAs, + StrideBs, + StrideDs, + StrideE, + }); + + group_id++; + } + + const auto e_grid_desc_sum_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1}; + + barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); + } + + // private: + index_t group_count_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation c_element_op_; + + std::vector gemm_desc_kernel_arg_; + std::vector> a_mtx_mraw_kraw_; + std::vector> b_mtx_nraw_kraw_; + + const void* grouped_gemm_kernel_args_dev; + + index_t grid_size_; + index_t grid_size_grp_; + index_t barrier_size_grp_; + index_t sum_of_m; + + index_t k_batch_ = 1; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + bool has_main_k_block_loop = true; + + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { + if(GridwiseGemm::CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i].K_) != + has_main_k_block_loop) + { + throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); + } + } + + if(arg.grouped_gemm_kernel_args_dev == nullptr) + { + throw std::runtime_error("wrong! 
grouped_gemm_kernel_args_dev is nullpr"); + } + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) { + const auto kernel = kernel_grouped_gemm_xdl_fixed_nk< + GridwiseGemm, + GroupedGemmMultiABDKernelArgument, + GemmSpec, + AsLayout, + BsLayout, + DsLayout, + ELayout, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + e_global_memory_operation_, + has_main_k_block_loop_>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + }; + + constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd; + constexpr auto Set = InMemoryDataOperationEnum::Set; + + if(arg.k_batch_ > 1) + { + if(has_main_k_block_loop) + { + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); + } + } + else + { + if(has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) + { + return false; + } + + bool supported = true; + + // If we use padding we do not support vector loads for dimensions not divisible by vector + // load size. + if constexpr(GemmSpec != GemmSpecialization::Default) + { + // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout, + // thus we have to adapt it to the {M,K} or {N,K} layout. + const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0; + const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 
1 : 0; + + for(index_t i = 0; i < arg.group_count_; ++i) + { + const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); + const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); + + supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); + supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); + } + } + + return supported; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector>& p_As, + std::vector>& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) + { + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector>& p_As, + std::vector>& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) override + { + return std::make_unique( + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemm_Xdl_Fixed_NK" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } + + static void SetElementwiseOps(Argument& arg, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + { + arg.a_element_op_ = a_element_op; + arg.b_element_op_ = b_element_op; + arg.c_element_op_ = c_element_op; + } + + static void SetDeviceKernelArgs(Argument& arg, const void* kernel_args) + { + arg.grouped_gemm_kernel_args_dev = kernel_args; + } + + // polymorphic + void SetDeviceKernelArgs(BaseArgument* p_arg, const void* kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), kernel_args); + } + + void SetElementwiseOps(BaseArgument* p_arg, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) const override + { + + SetElementwiseOps( + *dynamic_cast(p_arg), a_element_op, b_element_op, c_element_op); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + auto arg = *dynamic_cast(p_arg); + + return arg.group_count_ * + sizeof(GroupedGemmMultiABDKernelArgument); + } + +#if 0 + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = *dynamic_cast(p_arg); + + return arg.group_count_ * arg.barrier_size_grp_ * 
sizeof(uint32_t); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& stream_config = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + p_arg_->p_workspace_ = p_workspace; + + hip_check_error( + hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_)); + } +#endif + + static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } + + // polymorphic + void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override + { + return SetKBatch(*dynamic_cast(p_arg), k_batch); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp new file mode 100755 index 000000000000..10d8a4a44d23 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp @@ -0,0 +1,833 @@ +#pragma once +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_multiple_d_dl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \ + defined(__gfx12__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ && + block_id < gemm_desc_ptr[group_id].BlockEnd_)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::Run(gemm_desc_ptr[group_id].a_ptr_, + gemm_desc_ptr[group_id].b_ptr_, + gemm_desc_ptr[group_id].ds_ptr_, + gemm_desc_ptr[group_id].e_ptr_, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + gemm_desc_ptr[group_id].a_grid_desc_k0_m0_m1_k1_, + gemm_desc_ptr[group_id].b_grid_desc_k0_n0_n1_k1_, + gemm_desc_ptr[group_id].ds_grid_desc_m0_m10_m11_n0_n10_n11_, 
+ gemm_desc_ptr[group_id].e_grid_desc_m0_m10_m11_n0_n10_n11_, + gemm_desc_ptr[group_id].block_2_etile_map_, + integral_constant{}, + integral_constant{}); +#else + ignore = gemm_descs_const; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; +#endif +} + +template && + is_same_v, + bool> = false> +struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm +{ + using DeviceOp = DeviceGroupedGemmMultipleD_Dl; + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto K1Number = Number{}; + + static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(M, PadM)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB) + { + assert(K % K1 == 0); + + const index_t K0 = K / K1; + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + template + static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % 
NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)); + using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)); + using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {})); + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + // GridwiseGemm + using GridwiseGemm = + GridwiseGemmDlMultipleD_km_kn_mn; + + using AGridDesc_K0_M0_M1_K1 = + decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = + decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using DsGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{})); + using EGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{})); + + struct GroupedGemmBlock2ETileMap + { + using Block2ETileMap = + remove_cvref_t; + + GroupedGemmBlock2ETileMap() + { + block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{}); + BlockStart_ = -1; + } + + GroupedGemmBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart) + { + block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(e_grid_desc_m_n); + BlockStart_ = BlockStart; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return block_2_etile_map_.CalculateBottomIndex( + make_multi_index(idx_top[I0] - BlockStart_)); + } + + // it's actually E-Tile + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + __host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const + { + return block_2_etile_map_.CheckValidity(e_grid_desc_m_n); + } + + Block2ETileMap block_2_etile_map_; + ck::index_t BlockStart_; + }; + + struct GemmKernelArg + { + // pointers + const ADataType* a_ptr_; + const BDataType* b_ptr_; + typename GridwiseGemm::DsGridPointer ds_ptr_; + EDataType* e_ptr_; + + // tensor descriptors for problem definiton + AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; + BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_; + BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_; + DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_; + EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_; + + // block-to-e-tile map + GroupedGemmBlock2ETileMap block_2_etile_map_; + ck::index_t BlockStart_, BlockEnd_; + }; + + // Argument + struct Argument : 
public BaseArgument + { + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + gemm_kernel_host_args_{nullptr} + { + grid_size_ = 0; + + group_count_ = ck::type_convert(gemm_descs.size()); + + if(!(group_count_ == ck::type_convert(p_As.size()) && + group_count_ == ck::type_convert(p_Bs.size()) && + group_count_ == ck::type_convert(p_Es.size()))) + { + throw std::runtime_error("wrong! group_count_ != p_As/b/c.size"); + } + + gemm_desc_kernel_arg_.reserve(group_count_); + + skipped_group_count_ = 0; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + const index_t M = gemm_descs[i].M_; + const index_t N = gemm_descs[i].N_; + const index_t K = gemm_descs[i].K_; + + a_mtx_mraw_kraw_.emplace_back(M, K); + b_mtx_nraw_kraw_.emplace_back(N, K); + + if(M == 0) + { + skipped_group_count_++; + continue; + } + + const index_t StrideA = gemm_descs[i].stride_A_; + const index_t StrideB = gemm_descs[i].stride_B_; + const index_t StrideE = gemm_descs[i].stride_C_; + + typename GridwiseGemm::DsGridPointer p_ds_grid{}; + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + p_ds_grid(j) = static_cast(p_Ds[i][j]); + ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N( + M, N, gemm_descs[i].stride_Ds_[j]); + }); + + // tensor descriptors for problem definiton + const auto a_grid_desc_k0_m_k1 = + DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); + const auto b_grid_desc_k0_n_k1 = + DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); + const auto e_grid_desc_m_n = + DeviceOp::MakeEGridDescriptor_M_N(M, N, StrideE); + + if(GridwiseGemm::CheckValidity( + a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, e_grid_desc_m_n)) + { + + const index_t grid_size_grp = + GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0) + .block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n); + + const index_t BlockStart = grid_size_; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(e_grid_desc_m_n, BlockStart); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_k0_m0_m1_k1 = + GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1); + const auto b_grid_desc_k0_n0_n1_k1 = + GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(b_grid_desc_k0_n_k1); + const auto ds_grid_desc_m0_m10_m11_n0_n10_n11 = + GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n); + const auto e_grid_desc_m0_m10_m11_n0_n10_n11 = + GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(e_grid_desc_m_n); + + gemm_desc_kernel_arg_.push_back( + GemmKernelArg{static_cast(p_As[i]), + static_cast(p_Bs[i]), + p_ds_grid, + static_cast(p_Es[i]), + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_m_n, + e_grid_desc_m_n, + a_grid_desc_k0_m0_m1_k1, + b_grid_desc_k0_n0_n1_k1, + ds_grid_desc_m0_m10_m11_n0_n10_n11, + e_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_etile_map, + BlockStart, + BlockEnd}); + } + } + } + + // private: + index_t group_count_; + index_t skipped_group_count_; + + // TODO: A,B element op is unused since gridwise_gemm_dl_v1r3 does NOT support prologue + // for the time being. 
+ AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + std::vector gemm_desc_kernel_arg_; + std::vector> a_mtx_mraw_kraw_; + std::vector> b_mtx_nraw_kraw_; + + index_t grid_size_; + void* gemm_kernel_host_args_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, + const StreamConfig& stream_config = StreamConfig{}, + hipStream_t cpy_stream = nullptr, + hipEvent_t cpy_event = nullptr) + { + auto K0 = arg.gemm_desc_kernel_arg_[0].a_grid_desc_k0_m_k1_.GetLength(I0); + bool all_has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); + bool all_has_double_tail_k_block_loop = + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "group: " << i << " arg.a_grid_desc_k0_m_k1_{" + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I0) + << ", " + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I1) + << ", " + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_.GetLength(I2) + << "}" << std::endl; + + std::cout << ", arg.b_grid_desc_k0_n_k1_{" + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I0) + << ", " + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I1) + << ", " + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_.GetLength(I2) + << "}" << std::endl; + + std::cout << ", arg.e_grid_desc_m_n_{ " + << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", " + << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}" + << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_, + arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_, + arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_)) + { + throw std::runtime_error( + "wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting"); + } + + K0 = arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m0_m1_k1_.GetLength(I0); + bool not_all_has_main_k_block_loop_same = + all_has_main_k_block_loop xor GridwiseGemm::CalculateHasMainKBlockLoop(K0); + bool not_all_has_double_tail_k_block_loop_same = + all_has_double_tail_k_block_loop xor + GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); + + if(not_all_has_main_k_block_loop_same or not_all_has_double_tail_k_block_loop_same) + { + std::ostringstream err; + err << "Not all gemms have same value for [main|double_tail]_k_block_loop! in " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + + // If the user provides copy stream and copy event, we assume that they're also + // responsible for providing allocated host memory (eg. pinned) which + // would be used to copy kernel arguments to the device. + if(cpy_stream && cpy_event) + { + if(arg.gemm_kernel_host_args_ == nullptr) + { + std::ostringstream err; + err << "No memory has been allocated for gemm kernel host args " + << "when providing the copy stream and copy event! 
In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + hipGetErrorString(hipMemcpyAsync(arg.p_workspace_, + arg.gemm_kernel_host_args_, + arg.group_count_ * sizeof(GemmKernelArg), + hipMemcpyHostToDevice, + cpy_stream)); + hipGetErrorString(hipEventRecord(cpy_event, cpy_stream)); + hipGetErrorString(hipEventSynchronize(cpy_event)); + } + else // In this case CK owns memory allocated on host. + { + hipGetErrorString( + hipMemcpyAsync(arg.p_workspace_, + arg.gemm_desc_kernel_arg_.data(), + arg.gemm_desc_kernel_arg_.size() * sizeof(GemmKernelArg), + hipMemcpyHostToDevice, + stream_config.stream_id_)); + } + + auto launch_kernel = [&](auto has_main_k_block_loop, + auto has_double_tail_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + constexpr bool has_double_loop = has_double_tail_k_block_loop.value; + + const auto kernel = kernel_grouped_gemm_multiple_d_dl; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_desc_kernel_arg_.size(), + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_); + }; + + if(all_has_main_k_block_loop && all_has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(all_has_main_k_block_loop && !all_has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(!all_has_main_k_block_loop && all_has_double_tail_k_block_loop) + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if((ck::type_convert(arg.gemm_desc_kernel_arg_.size()) + + arg.skipped_group_count_) != arg.group_count_) + { + return false; + } + + if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() || + ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { + if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_k0_m_k1_, + arg.gemm_desc_kernel_arg_[i].b_grid_desc_k0_n_k1_, + arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_)) + { + return false; + } + } + return true; + } + else + { + return false; + } + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + { + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return std::make_unique( + p_As, p_Bs, 
p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, cde_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemmMultipleD_Dl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << K0PerBlock << ", " + << K1 << ", " + << M1PerThread << ", " + << N1PerThread << ", " + << KPerThread + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + return dynamic_cast(p_arg)->group_count_ * sizeof(GemmKernelArg); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + return GetWorkSpaceSize(p_arg); + } + + size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); } + + void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args); + } + + //---------------------------------------------------------------------------------------------- + /// @brief Sets the host kernel arguments pointer and copies that data on the host side. + /// This function can be utilised to use pinned memory for the host args and + /// achieve fully async data copy. + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_host_kernel_args The pointer to the host memory where the kernel + /// arguments will be copied + void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const + { + Argument* pArg_ = dynamic_cast(p_arg); + if(!pArg_) + { + throw std::runtime_error("Failed to cast argument pointer!"); + } + + pArg_->gemm_kernel_host_args_ = p_host_kernel_args; + std::copy(pArg_->gemm_desc_kernel_arg_.begin(), + pArg_->gemm_desc_kernel_arg_.end(), + static_cast(pArg_->gemm_kernel_host_args_)); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp new file mode 100755 index 000000000000..18872e38eaf2 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp @@ -0,0 +1,1017 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
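+
+// Two-stage grouped GEMM with multiple auxiliary (D) tensors and split-K support:
+// stage one runs the split-K GridwiseGemm for every group and accumulates partial
+// results into a float workspace; stage two launches one elementwise kernel per
+// group that applies CDEElementwiseOperation to the workspace and the D tensors
+// and writes the final E tensor.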
+ +#pragma once + +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/env.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/common_header.hpp" +#include +#include "ck/utility/tuple.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include +#include + +namespace ck { +namespace tensor_operation { +namespace device { + +template = false> +struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage + : public DeviceGroupedGemmSplitK +{ + using DeviceOp = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + // TODO change GridwiseGEMM v2r4r2 to support separate AK1 & BK1 + static constexpr index_t K0PerBlock = KPerBlock / AK1; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using WorkspaceDataType = float; + + // First stage GridwiseGEMM kernel. + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + BlockSize, + ADataType, + BDataType, + AccDataType, + WorkspaceDataType, + ALayout, + BLayout, + ELayout, + AElementwiseOperation, + BElementwiseOperation, + PassThrough, // CElementwiseOperation + GemmSpec, + NumGemmKPrefetchStage, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + AK1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopSched, + PipelineVer, + ComputeDataType>; + + template + static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE)); + } + }(); + + if constexpr(GemmSpec == GemmSpecialization::MNPadding) + { + const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)), + 
make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + static constexpr auto MakeElementwiseInputSequence() + { + return generate_sequence_v2( + [&]([[maybe_unused]] auto i) constexpr { + return Number{}; + }, + Number{}); + } + + using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; + using EGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; + using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {})); + using DsGridPointer = decltype(MakeDsGridPointer()); + using CDGridDesc_M_N = decltype(concat_tuple(ck::Tuple{}, DsGridDesc_M_N{})); + using CDDataTypes = decltype(concat_tuple(ck::Tuple{}, DsGridPointer{})); + + using ElementwiseInputSequence = decltype(MakeElementwiseInputSequence()); + + static constexpr index_t ClusterLengthMPerBlock = + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + using Block2ETileMapKSplit = + BlockToCTileMap_KSplit_M00_N0_M01Adapt; + using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt; + using GridwiseElementwise = + GridwiseElementwise, + CDDataTypes, + ck::Tuple, + Block2TileMap, + CDEElementwiseOperation, + BlockSize, + MPerBlock, + NPerBlock, + MPerBlock / ClusterLengthMPerBlock, + NPerBlock / ClusterLengthNPerBlock, + Sequence<0, 1>, + ElementwiseInputSequence, + ck::Sequence, + I1, + I1>; + + // Block2CTileMap configuration parameter. 
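+    // B2E_M01 is handed to Block2ETileMapKSplit whenever a group's block-to-E-tile
+    // map is built (see the Argument constructor and UpdateKBatch below).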
+ static constexpr index_t B2E_M01 = 8; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; + using GemmKernelArgument = typename GridwiseGemm::Argument; + + struct GemmTransKernelArg + { + GemmKernelArgument karg_; + GroupedGemmBlock2ETileMap block_2_ctile_map_; + index_t block_start_, block_end_; + + GemmTransKernelArg() = default; + GemmTransKernelArg(GemmKernelArgument&& karg, + GroupedGemmBlock2ETileMap&& b2c_map, + index_t block_start, + index_t block_end) + : karg_{karg}, + block_2_ctile_map_{b2c_map}, + block_start_{block_start}, + block_end_{block_end} + { + } + }; + + static constexpr index_t DefaultKBatch = 1; + + // Argument + struct Argument : public BaseArgument + { + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : Argument(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_element_op, + b_element_op, + cde_element_op, + DefaultKBatch) + { + } + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t kbatch) + : K_BATCH{kbatch}, + group_count_{0}, + skipped_group_count_{0}, + grid_size_{0}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + p_Ds_{p_Ds} + { + group_count_ = ck::type_convert(gemm_descs.size()); + + if(!(group_count_ == ck::type_convert(p_As.size()) && + group_count_ == ck::type_convert(p_Bs.size()) && + group_count_ == ck::type_convert(p_Es.size()))) + { + throw std::runtime_error("Error! 
group_count_ != p_As/Bs/Ds/Es size"); + } + + gemm_kernel_args_.reserve(group_count_); + elementwise_c_grid_descs_m_n_.reserve(group_count_); + elementwise_d_grid_descs_m_n_.reserve(group_count_); + ds_grid_pointer_.reserve(group_count_); + group_grid_size_.reserve(group_count_); + e_ptrs_.reserve(group_count_); + + for(std::size_t i = 0; i < gemm_descs.size(); ++i) + { + const index_t M = gemm_descs[i].M_; + const index_t N = gemm_descs[i].N_; + const index_t K = gemm_descs[i].K_; + + if(M == 0 || N == 0 || K == 0) + { + skipped_group_count_++; + continue; + } + + const index_t stride_a = gemm_descs[i].stride_A_; + const index_t stride_b = gemm_descs[i].stride_B_; + const index_t stride_e = gemm_descs[i].stride_C_; + + const index_t m_padded = GridwiseGemm::CalculateMPadded(M); + const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH); + const index_t k0_padded = GridwiseGemm::CalculateK0Padded(K, K_BATCH); + + const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_e); + + DsGridDesc_M_N ds_grid_desc_m_n; + DsGridPointer p_ds_grid; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + using DDataType = remove_cvref_t>; + + p_ds_grid(j) = static_cast(p_Ds[i][j]); + ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N( + M, N, gemm_descs[i].stride_Ds_[j]); + }); + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + group_grid_size_.push_back(grid_size_grp); + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + std::array stride_ds; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + if(gemm_descs[i].stride_Ds_.size() != NumDTensor) + { + throw std::runtime_error( + "Error! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + stride_ds[j] = gemm_descs[i].stride_Ds_[j]; + }); + stride_Ds_.emplace_back(std::move(stride_ds)); + + // We first set E pointer to actual operation output, but later on + // when workspace will be set, this will be updated to workspace memory. + auto karg = GemmKernelArgument{type_convert(p_As[i]), + type_convert(p_Bs[i]), + type_convert(p_Es[i]), + M, + N, + K, + stride_a, + stride_b, + stride_e, + m_padded, + n_padded, + k_padded, + k0_padded, + K_BATCH}; + + gemm_kernel_args_.emplace_back( + std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); + + elementwise_c_grid_descs_m_n_.push_back(c_grid_desc_m_n); + elementwise_d_grid_descs_m_n_.push_back(ds_grid_desc_m_n); + ds_grid_pointer_.push_back(p_ds_grid); + // Store a copy of E pointers for elementwise kernel destination + e_ptrs_.push_back(p_Es[i]); + } + } + + /** + * @brief Set new kbatch value. + * + * @param[in] kbatch The new splitK parameter value. 
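+         *
+         * @note Changing kbatch recomputes KPadded, K0Padded, every group's grid size
+         *       and block-to-E-tile map, and hence the workspace size reported by
+         *       GetWorkspaceSizeBytes(). Call it before SetWorkSpacePointer so that
+         *       UpdateEPointers partitions the workspace with the new k_batch value.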
+ */ + void UpdateKBatch(index_t kbatch) + { + K_BATCH = kbatch; + grid_size_ = 0; + + for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i) + { + auto& karg = gemm_kernel_args_[i].karg_; + + const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH); + const index_t k0_padded = GridwiseGemm::CalculateK0Padded(karg.K, K_BATCH); + + const auto c_grid_desc_m_n = + GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); + + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + group_grid_size_[i] = grid_size_grp; + karg.KPadded = k_padded; + karg.K0Padded = k0_padded; + karg.k_batch = K_BATCH; + gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map; + gemm_kernel_args_[i].block_start_ = block_start; + gemm_kernel_args_[i].block_end_ = block_end; + + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + index_t tiles = (block_end - block_start) / K_BATCH; + std::cout << "block_start: " << block_start << "\n" + << "block_end: " << block_end << "\n" + << "tiles: " << tiles << std::endl + << std::endl; + + std::cout << "KPadded: " << karg.KPadded << std::endl + << "K0Padded: " << karg.K0Padded << std::endl + << "KBatch: " << karg.k_batch << std::endl + << "grid_size_: " << karg.KPadded << std::endl; + } + } + } + + void UpdateEPointers() + { + // set-up each group E pointer to it's designated workspace memory. + WorkspaceDataType* p_workspace = reinterpret_cast(p_workspace_); + std::size_t offset = 0; + + for(auto& arg : gemm_kernel_args_) + { + arg.karg_.p_c_grid = p_workspace + offset; + index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch; + offset += tiles * MPerBlock * NPerBlock; + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "block_start: " << arg.block_start_ << "\n" + << "block_end: " << arg.block_end_ << "\n" + << "tiles: " << tiles << "\n" + << "offset: " << offset << std::endl; + } + } + } + + std::size_t GetWorkspaceSizeBytes() const + { + std::size_t size_bytes{0}; + + for(const auto& arg : gemm_kernel_args_) + { + index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch; + size_bytes += tiles * MPerBlock * NPerBlock * sizeof(WorkspaceDataType); + } + return size_bytes; + } + + std::size_t GetWorkspaceSize(std::size_t group) const + { + const auto& arg = gemm_kernel_args_[group]; + index_t tiles = (arg.block_end_ - arg.block_start_) / arg.karg_.k_batch; + return tiles * MPerBlock * NPerBlock; + } + + // private: + index_t K_BATCH; + index_t group_count_; + index_t skipped_group_count_; + index_t grid_size_; + // Pointer to device memory with GEMM kernel arguments. + void* p_dev_gemm_kargs_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + std::vector>& p_Ds_; + std::vector> stride_Ds_; + std::vector gemm_kernel_args_; + std::vector group_grid_size_; + + std::vector elementwise_c_grid_descs_m_n_; + std::vector elementwise_d_grid_descs_m_n_; + std::vector ds_grid_pointer_; + std::vector e_ptrs_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + /// + /// @brief Launch Grouped Gemm kernel. 
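+        /// Copies the per-group kernel arguments into dev_gemm_args, zeroes
+        /// dev_gemm_workspace, then launches the split-K GEMM kernel over all groups
+        /// followed by one elementwise kernel per group that writes the E outputs.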
+ /// + /// @note This function overload is using user provided device buffer for kernel + /// arguments. + /// + /// @param[in] arg The structure containing kernel arguments (in host + /// memory). + /// @param[in] dev_gemm_args The pointer to device memory with kernel arguments. + /// @param[in] dev_gemm_workspace The pointer to device memory for kernel auxiliary + /// workspace. + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, + void* dev_gemm_args, + void* dev_gemm_workspace, + const StreamConfig& stream_config = StreamConfig{}) + { + auto [all_have_kbatch_gt_one, all_have_main_k_block_loop] = + CheckArgument(arg, stream_config); + + if(dev_gemm_args == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(dev_gemm_workspace == nullptr) + { + std::ostringstream err; + err << "The gemm workspace buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + float ave_time = 0; + + if(all_have_main_k_block_loop) + { + ave_time = + DispatchKernel(arg, dev_gemm_args, dev_gemm_workspace, stream_config); + } + else + { + ave_time = + DispatchKernel(arg, dev_gemm_args, dev_gemm_workspace, stream_config); + } + + return ave_time; + } + + /// + /// @brief Launch Grouped Gemm kernel. + /// + /// @note This function overload is using device buffers (for kernel arguments and + /// for kernel auxiliary workspace) provided with an argument. The user should + /// call @see GetDeviceKernelArgSize, @see GetWorkSpaceSize and @see + /// SetDeviceKernelArgs, @see SetWorkSpacePointer on arg parameter to properly + /// allocate those buffers. + /// + /// @param[in] arg The structure containing kernel arguments (in host memory). + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(arg.p_dev_gemm_kargs_ == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(arg.p_workspace_ == nullptr) + { + std::ostringstream err; + err << "The gemm workspace buffer is not allocated!" 
+ << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + return Run(arg, arg.p_dev_gemm_kargs_, arg.p_workspace_, stream_config); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + + private: + auto CheckArgument(const Argument& arg, const StreamConfig& stream_config) const + { + bool all_have_kbatch_gt_one, all_have_main_k_block_loop; + + { + const auto a_grid_desc_kbatch_ak0_m_ak1 = + GridwiseGemm::MakeAGridDescriptor_KBatch_K0_M_K1( + arg.gemm_kernel_args_[0].karg_.M, + arg.gemm_kernel_args_[0].karg_.MPadded, + arg.gemm_kernel_args_[0].karg_.K, + arg.gemm_kernel_args_[0].karg_.StrideA, + arg.gemm_kernel_args_[0].karg_.k_batch, + arg.gemm_kernel_args_[0].karg_.K0Padded, + arg.gemm_kernel_args_[0].karg_.KPadded); + + all_have_kbatch_gt_one = arg.K_BATCH > 1; + all_have_main_k_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop( + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)); + } + + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_; + if(stream_config.log_level_ > 0) + { + gemm_arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(gemm_arg)) + { + std::ostringstream err; + err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + const auto a_grid_desc_kbatch_ak0_m_ak1 = + GridwiseGemm::MakeAGridDescriptor_KBatch_K0_M_K1(gemm_arg.M, + gemm_arg.MPadded, + gemm_arg.K, + gemm_arg.StrideA, + gemm_arg.k_batch, + gemm_arg.K0Padded, + gemm_arg.KPadded); + + bool not_all_have_main_k_block_loop_same = + all_have_main_k_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop( + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)); + bool not_all_have_kbatch_value_same = + all_have_kbatch_gt_one xor (gemm_arg.k_batch > 1); + + if(not_all_have_main_k_block_loop_same) + { + std::ostringstream err; + err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(not_all_have_kbatch_value_same) + { + std::ostringstream err; + err << "Not all gemms have same kbatch value (=1 or >1)! 
" + << "group [" << i << "], kbatch: " << gemm_arg.k_batch + << ", group [0], kbatch: " << gemm_arg.k_batch << " in " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + return std::make_tuple(all_have_kbatch_gt_one, all_have_main_k_block_loop); + } + + template + float DispatchKernel(const Argument& arg, + void* dev_gemm_kargs, + void* dev_gemm_workspace, + const StreamConfig& stream_config) const + { + const auto gemm_kernel = + kernel_grouped_gemm_xdl_splitk; + + const auto elementwise_kernel = kernel_elementwise, + CDDataTypes, + ck::Tuple, + Block2TileMap, + CDEElementwiseOperation>; + return LaunchKernel(gemm_kernel, + elementwise_kernel, + arg, + dev_gemm_kargs, + dev_gemm_workspace, + stream_config); + } + + template + float LaunchKernel(const KernelFunction& gemm_kernel, + const KernelFunction2& elementwise_kernel, + const Argument& arg, + void* dev_gemm_kargs, + [[maybe_unused]] void* dev_gemm_workspace, + const StreamConfig& stream_config) const + { + float time{0.f}; + + hip_check_error( + hipMemcpyAsync(dev_gemm_kargs, + arg.gemm_kernel_args_.data(), + arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg), + hipMemcpyHostToDevice, + stream_config.stream_id_)); + + auto preprocess = [&]() { + hip_check_error(hipMemsetAsync( + dev_gemm_workspace, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_)); + }; + + // GEMM kernel + time = launch_and_time_kernel_with_preprocess( + stream_config, + preprocess, + gemm_kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(dev_gemm_kargs), + arg.gemm_kernel_args_.size(), + arg.a_element_op_, + arg.b_element_op_, + PassThrough{}); + + // Elementwise kernels + for(size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + time += launch_and_time_kernel( + stream_config, + elementwise_kernel, + dim3(arg.group_grid_size_[i]), + dim3(BlockSize), + 0, + concat_tuple(make_tuple(arg.elementwise_c_grid_descs_m_n_[i]), + arg.elementwise_d_grid_descs_m_n_[i]), + make_tuple(arg.elementwise_c_grid_descs_m_n_[i]), + concat_tuple(make_tuple(arg.gemm_kernel_args_[i].karg_.p_c_grid), + arg.ds_grid_pointer_[i]), + type_convert(arg.e_ptrs_[i]), + Block2TileMap{arg.elementwise_c_grid_descs_m_n_[i].GetLength(I0), + arg.elementwise_c_grid_descs_m_n_[i].GetLength(I1)}, + arg.cde_element_op_); + } + return time; + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((ck::type_convert(arg.gemm_kernel_args_.size()) + + arg.skipped_group_count_) != arg.group_count_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The group count is not equal to sum of skipped groups " + "and kernel args size!" + << std::endl; + } + return false; + } + + bool supported = true; + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& gemm_arg = arg.gemm_kernel_args_[i].karg_; + + bool group_arg_valid = GridwiseGemm::CheckValidity(gemm_arg); + if(not group_arg_valid) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[" << __func__ << "] group id: " << i + << " has invalid GridwiseGemm settings!" 
<< std::endl; + gemm_arg.Print(); + } + } + supported = supported && group_arg_valid; + } + return supported; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) + { + return Argument{p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op}; + } + + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) override + { + return std::make_unique(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage" + << "<" + << std::string(ALayout::name)[0] << "," + << std::string(BLayout::name)[0] << "," + << std::string(ELayout::name)[0] << "," + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << ">"; + // clang-format on + + return str.str(); + } + + void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->p_dev_gemm_kargs_ = p_dev_kernel_args; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); + } + + void SetWorkSpacePointer( + BaseArgument* p_arg, + void* p_workspace, + [[maybe_unused]] const StreamConfig& stream_config = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + p_arg_->UpdateEPointers(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); + } + + 
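+
+    // Typical host-side call sequence for this two-stage device op. This is an
+    // illustrative sketch only: p_As/p_Bs/p_Ds/p_Es, gemm_descs, a_op/b_op/cde_op
+    // and kbatch are caller-provided placeholders, and error handling is omitted.
+    //
+    //   auto op      = DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage</*...*/>{};
+    //   auto arg     = op.MakeArgumentPointer(p_As, p_Bs, p_Ds, p_Es, gemm_descs,
+    //                                         a_op, b_op, cde_op);
+    //   auto invoker = op.MakeInvokerPointer();
+    //
+    //   op.SetKBatchSize(arg.get(), kbatch);   // optional, DefaultKBatch is 1
+    //
+    //   void* dev_kargs     = nullptr;
+    //   void* dev_workspace = nullptr;
+    //   hipMalloc(&dev_kargs, op.GetDeviceKernelArgSize(arg.get()));
+    //   hipMalloc(&dev_workspace, op.GetWorkSpaceSize(arg.get()));
+    //   op.SetDeviceKernelArgs(arg.get(), dev_kargs);
+    //   op.SetWorkSpacePointer(arg.get(), dev_workspace);
+    //
+    //   if(op.IsSupportedArgument(arg.get()))
+    //       invoker->Run(arg.get(), StreamConfig{});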
[[deprecated]] static void SetKBatchSize(Argument& arg, index_t kbatch) + { + arg.UpdateKBatch(kbatch); + } + + void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->UpdateKBatch(kbatch); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp new file mode 100755 index 000000000000..61058dec2b25 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp @@ -0,0 +1,974 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/host_utility/stream_utility.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // stare wywalic +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/// +/// @brief Entry point kernel for device-wide Grouped GEMM operation. +/// +/// @param[in] gemm_descs_const The pointer to the array of GEMM descriptor structures. +/// @param[in] group_count The number of together processed GEMMs. +/// +/// @tparam GridwiseGemm The specific GridwiseGEMM algorithm implementation. +/// @tparam GemmDesc The structure holding all necessary descriptors and +/// other data needed for grouped gemm calculation and work +/// distribution. +/// @tparam LocalBlock2ETileMap The structure providing mapping between workgroup ids, +/// the data tiles to process and the output tiles. 
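+/// @param[in] a_element_op The elementwise operation applied to the A input.
+/// @param[in] b_element_op The elementwise operation applied to the B input.
+/// @param[in] cde_element_op The elementwise operation fusing the GEMM result with
+///                           the D tensors into the E output.
+///
+/// @note Each workgroup loops persistently over the flattened tile space of all
+///       groups: starting from its block id it finds the owning group through the
+///       [gemm_tile_id_start, gemm_tile_id_end) ranges, runs the selected
+///       GridwiseGemm pipeline for that output tile, then advances by the grid
+///       size until every tile has been processed.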
+/// +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_multiple_d_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + __shared__ uint8_t p_shared[shared_size]; + __shared__ uint8_t p_shared1[shared_size]; + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + constexpr auto NumDTensor = DsDataType::Size(); + index_t tile_id = get_block_1d_id(); + index_t tile_offset = 0; + index_t group_id = -1; + index_t group_offset = 0; + index_t grid_size_grp = 0; + + index_t gemm_tile_id_start = 0; + index_t gemm_tile_id_end = 0; + + index_t M = 0, N = 0, K = 0; + + auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1); + + do + { + // Find corresponding GEMM group for our tile + while(!(tile_id >= gemm_tile_id_start && tile_id < gemm_tile_id_end) && + group_id < group_count) + { + group_offset += grid_size_grp; + group_id++; + + if(group_id >= group_count) + return; + + M = gemm_desc_ptr[group_id].M; + N = gemm_desc_ptr[group_id].N; + K = gemm_desc_ptr[group_id].K; + + if(M == 0 || N == 0 || K == 0) + { + grid_size_grp = 0; + continue; + } + + b2c_tile_map = + OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset); + grid_size_grp = b2c_tile_map.CalculateGridSize(M, N); + + gemm_tile_id_start = group_offset; + gemm_tile_id_end = group_offset + grid_size_grp; + } + + using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); + DsGridPointer p_ds_grid; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + p_ds_grid(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + static constexpr index_t kbatch = 1; + static constexpr index_t k_grain = kbatch * KPerBlock; + index_t K_split = (K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + // Update tile offset if we have moved within group + b2c_tile_map.UpdateTileOffset(tile_offset); + + using Problem = typename GridwiseGemm::Problem; + auto problem = Problem(gemm_desc_ptr[group_id].M, + gemm_desc_ptr[group_id].N, + gemm_desc_ptr[group_id].K, + gemm_desc_ptr[group_id].StrideA, + gemm_desc_ptr[group_id].StrideB, + gemm_desc_ptr[group_id].StrideDs, + gemm_desc_ptr[group_id].StrideE, + kbatch); + + if(has_main_k_block_loop) + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + 
problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + GridwiseGemm::template Run_2Lds( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + 
static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + static_cast(p_shared1), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + else + { + GridwiseGemm::template Run_2Lds( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + static_cast(p_shared1), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + GridwiseGemm::template Run( + static_cast(gemm_desc_ptr[group_id].p_a_grid), + static_cast(gemm_desc_ptr[group_id].p_b_grid), + p_ds_grid, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + a_element_op, + b_element_op, + cde_element_op, + b2c_tile_map); + } + } + + tile_id += get_grid_size(); + tile_offset += get_grid_size(); + + } while(group_id < group_count); +#else + ignore = gemm_descs_const; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; +#endif // end of if (defined(__gfx9__)) +} + +template + +struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop + : public DeviceGroupedGemmTileLoop +{ + using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop; + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using KernelArguments = GroupedGemmKernelArgument; + using Block2ETileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::vector& /* p_As */, + std::vector& /* p_Bs */, + std::vector>& /* p_Ds */, + std::vector& /* p_Es */, + const std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + int occupancy_num_blocks, + int gpu_cu_count) + : group_count_{static_cast(gemm_descs.size())}, + occupancy_num_blocks_{occupancy_num_blocks}, + gpu_cu_count_{gpu_cu_count}, + gemm_descs_{gemm_descs}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + tile_count_{0} + { + 
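+            // Sum the number of output (E) tiles over all groups; the total is
+            // reported next to the occupancy-based grid size when logging is enabled.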
for(const auto& desc : gemm_descs) + { + const auto M = desc.M_; + const auto N = desc.N_; + const auto b2c_tile_map = Block2ETileMap(M, N); + tile_count_ += b2c_tile_map.CalculateGridSize(M, N); + } + } + + index_t group_count_; + const void* p_dev_gemm_args_; + int occupancy_num_blocks_; + int gpu_cu_count_; + const std::vector& gemm_descs_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + index_t tile_count_; + }; + + struct KernelConfig + { + // The oversubscription factor for the number of blocks that can simultaneously reside on + // GPU. + static constexpr int BLOCK_SUBSCRIPTION_FACTOR = 1; + static constexpr int BLOCK_WAVES = BlockSize / get_warp_size(); + static constexpr int CU_SIMDS = 4; + // Assume we want to have at most 2 waves per SIMD + static constexpr int CU_BLOCKS = math::integer_divide_floor(2 * CU_SIMDS, BLOCK_WAVES); + }; + + // Invoker + struct Invoker : public BaseInvoker + { + /// + /// @brief Launch Grouped Gemm kernel. + /// + /// @note This function overload is using user provided device buffer for kernel + /// arguments. + /// + /// @param[in] arg The structure containing kernel arguments (in host + /// memory). + /// @param[in] dev_gemm_args The pointer to device memory with kernel arguments. + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, + const void* dev_gemm_args, + const StreamConfig& stream_config = StreamConfig{}) + { + if(dev_gemm_args == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" + << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + float ave_time = 0; + ave_time = DispatchKernel(arg, dev_gemm_args, stream_config); + + return ave_time; + } + + /// + /// @brief Launch Grouped Gemm kernel. + /// + /// @note This function overload is using device buffers (for kernel arguments and + /// for kernel auxiliary workspace) provided with an argument. The user should + /// call @see GetDeviceKernelArgSize, and @see SetDeviceKernelArgs, on arg + /// parameter to properly allocate those buffers. + /// + /// @param[in] arg The structure containing kernel arguments (in host memory). + /// @param[in] stream_config The device stream configuration. + /// + /// @return The average kernel execution time (if time measurement is enabled.) + /// + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(arg.p_dev_gemm_args_ == nullptr) + { + std::ostringstream err; + err << "The gemm arguments device buffer is not allocated!" 
+ << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + return Run(arg, arg.p_dev_gemm_args_, stream_config); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + + private: + float DispatchKernel(const Argument& arg, + const void* dev_gemm_args, + const StreamConfig& stream_config) const + { + const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + return LaunchKernel(kernel, arg, dev_gemm_args, stream_config); + } + + template + int CalculateMaxOccupancyGridSize(const KernelFunction& kernel, + const StreamConfig& stream_config) const + { + // Calculate max number of workgroups that can simultaneously reside on the CU. + int occ_num_blocks = 0; + size_t dyn_shared_mem_per_blk = 0; + hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &occ_num_blocks, kernel, BlockSize, dyn_shared_mem_per_blk)); + + int cu_count = getAvailableComputeUnitCount(stream_config); + + if(stream_config.log_level_ > 0) + { + std::cout << "MaxActiveBlocksPerCU: " << occ_num_blocks + << ", available CUs count: " << cu_count << ", occup. grid size: " + << ck::math::min(occ_num_blocks, KernelConfig::CU_BLOCKS) * cu_count + << std::endl; + } + + return cu_count * ck::math::min(occ_num_blocks, KernelConfig::CU_BLOCKS); + } + + template + float LaunchKernel(const KernelFunction& kernel, + const Argument& arg, + const void* dev_gemm_args, + const StreamConfig& stream_config) const + { + int grid_size = CalculateMaxOccupancyGridSize(kernel, stream_config); + + if(stream_config.log_level_ > 0) + { + std::cout << "grid_size: " << grid_size << " tile_count: " << arg.tile_count_ + << std::endl; + } + + // run multiple kernels + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(dev_gemm_args), + arg.group_count_, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + bool supported = true; + + constexpr index_t k_batch = 1; + for(index_t i = 0; i < arg.group_count_; ++i) + { + std::array placeholder_p_ds_grid{}; + std::array stride_Ds; + std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin()); + using GridArg = typename GridwiseGemm::Argument; + GridArg gridwise_arg(nullptr, // p_a_grid, + nullptr, // p_b_grid, + placeholder_p_ds_grid, // p_ds_grid, + nullptr, // p_e_grid , + arg.gemm_descs_[i].M_, + arg.gemm_descs_[i].N_, + arg.gemm_descs_[i].K_, + arg.gemm_descs_[i].stride_A_, + arg.gemm_descs_[i].stride_B_, + stride_Ds, + arg.gemm_descs_[i].stride_C_, + k_batch, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_); + + if((arg.gemm_descs_[i].K_ % AK1 != 0 || arg.gemm_descs_[i].K_ % BK1 != 0) && + !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + supported = supported && GridwiseGemm::CheckValidity(gridwise_arg); + } + + return supported; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto 
MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) + { + const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + int occupancy, num_cu; + hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + + hipDeviceProp_t dev_prop; + hipDevice_t dev; + hip_check_error(hipGetDevice(&dev)); + hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); + num_cu = dev_prop.multiProcessorCount; + + return Argument{p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op, + occupancy, + num_cu}; + } + + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_elementwise_op, + BElementwiseOperation b_elementwise_op, + CDEElementwiseOperation cde_elementwise_op) override + { + const auto kernel = kernel_grouped_gemm_multiple_d_xdl; + int occupancy, num_cu; + hip_check_error( + hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0)); + + hipDeviceProp_t dev_prop; + hipDevice_t dev; + hip_check_error(hipGetDevice(&dev)); + hip_check_error(hipGetDeviceProperties(&dev_prop, dev)); + num_cu = dev_prop.multiProcessorCount; + + return std::make_unique(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_elementwise_op, + b_elementwise_op, + cde_elementwise_op, + occupancy, + num_cu); + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::ostringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGroupedGemmMultipleDXdlCShuffleTileLoop" + << "<" + << std::string(ALayout::name)[0] << "," + << std::string(BLayout::name)[0] << "," + << std::string(ELayout::name)[0] << "," + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "BlkGemmPipelineScheduler: " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << "BlkGemmPipelineVersion: " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] + << ">"; + // clang-format on + + return str.str(); + } + + void SetDeviceKernelArgs(Argument& arg, + void* p_dev_kernel_args, + const void* p_host_kernel_args) const + { + arg.p_dev_gemm_args_ = p_dev_kernel_args; + hip_check_error(hipMemcpyAsync(p_dev_kernel_args, + p_host_kernel_args, + GetDeviceKernelArgSize(&arg), + hipMemcpyHostToDevice)); + } + + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, + void* p_dev_kernel_args, + const void* 
p_host_kernel_args) const override + { + return SetDeviceKernelArgs( + *dynamic_cast(p_arg), p_dev_kernel_args, p_host_kernel_args); + } + + void SetDeviceKernelArgs(Argument& arg, void* p_dev_kernel_args) const + { + arg.p_dev_gemm_args_ = p_dev_kernel_args; + } + + virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return SetDeviceKernelArgs(*dynamic_cast(p_arg), p_dev_kernel_args); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + return dynamic_cast(p_arg)->group_count_ * sizeof(KernelArguments); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp new file mode 100755 index 000000000000..3fb2c5ae86d6 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp @@ -0,0 +1,892 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1( + const void CK_CONSTANT_ADDRESS_SPACE* group_kernel_args, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const B1ElementwiseOperation b1_element_op, + const CElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + + const auto arg_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(group_kernel_args)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + + while( + (!(block_id >= arg_ptr[group_id].block_start_ && block_id < arg_ptr[group_id].block_end_))) + { + if(block_id < arg_ptr[group_id].block_start_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + // per-group batch offset + const index_t num_blocks_per_batch = arg_ptr[group_id].num_blocks_per_batch_; + const index_t g_idx = __builtin_amdgcn_readfirstlane( + (block_id - arg_ptr[group_id].block_start_) / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + 
static_cast(arg_ptr[group_id].compute_base_ptr_of_batch_.GetABasePtr(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(arg_ptr[group_id].compute_base_ptr_of_batch_.GetBBasePtr(g_idx))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast( + arg_ptr[group_id].compute_base_ptr_of_batch_.GetB1BasePtr(g_idx))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(arg_ptr[group_id].compute_base_ptr_of_batch_.GetCBasePtr(g_idx))); + + GridwiseGemm::template Run( + arg_ptr[group_id].p_a_grid_ + a_batch_offset, + arg_ptr[group_id].p_b_grid_ + b_batch_offset, + arg_ptr[group_id].p_b1_grid_ + b1_batch_offset, + arg_ptr[group_id].p_c_grid_ + c_batch_offset, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op, + arg_ptr[group_id].a_grid_desc_ak0_m_ak1_, + arg_ptr[group_id].b_grid_desc_bk0_n_bk1_, + arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_, + arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg_ptr[group_id].block_2_ctile_map_, + arg_ptr[group_id].c0_matrix_mask_); +#else + ignore = group_kernel_args; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = acc_element_op; + ignore = b1_element_op; + ignore = c_element_op; +#endif // end of if (defined(__gfx9__)) +} + +// Computes C = A * B0 * B1 +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle + : public DeviceGroupedGemmSoftmaxGemmPermute +{ + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0 && NumDimO > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + // TODO ANT: implement bias combination + static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); + +#if 0 + // TODO ANT: use alias + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimN; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimO; + static constexpr index_t NumDimGemm1K = NumDimN; +#endif + + using DeviceOp = DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle; + using ProblemDesc = typename DeviceGroupedGemmSoftmaxGemmPermute::ProblemDesc; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm< + Sequence, + Sequence, + GemmSpec, + ASpec, + BSpec, + B1Spec, + CSpec>; + + static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); + } + + static auto MakeBGridDescriptor_BK0_N_BK1(const std::vector& b_gs_ns_ks_lengths_vec, + const std::vector& b_gs_ns_ks_strides_vec) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b_gs_ns_ks_lengths_vec, b_gs_ns_ks_strides_vec), + Number{}); + } + + static auto + MakeB1GridDescriptor_BK0_N_BK1(const std::vector& b1_gs_gemm1ns_gemm1ks_lengths_vec, + const std::vector& b1_gs_gemm1ns_gemm1ks_strides_vec) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + 
Transform::MakeB1GridDescriptor_N_K(b1_gs_gemm1ns_gemm1ks_lengths_vec, + b1_gs_gemm1ns_gemm1ks_strides_vec), + Number{}); + } + + using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1({}, {})); + using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1({}, {})); + using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using BGridDesc_G_N_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_K = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + + constexpr static auto make_MaskOutPredicate() + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + return MaskDisabledPredicate{}; + } + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + return MaskOutUpperTrianglePredicate{}; + } + } + using C0MatrixMask = C0MatrixMask_impl; + + struct ComputeBasePtrOfStridedBatch + { + ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const BGridDesc_G_N_K& b_grid_desc_g_n_k, + const B1GridDesc_G_N_K& b1_grid_desc_g_n_k, + const CGridDesc_G_M_N& c_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b_grid_desc_g_n_k_(b_grid_desc_g_n_k), + b1_grid_desc_g_n_k_(b1_grid_desc_g_n_k), + c_grid_desc_g_m_n_(c_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const + { + return b_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return b1_grid_desc_g_n_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + private: + AGridDesc_G_M_K a_grid_desc_g_m_k_; + BGridDesc_G_N_K b_grid_desc_g_n_k_; + B1GridDesc_G_N_K b1_grid_desc_g_n_k_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + }; + + // GridwiseGemm + using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< + ADataType, // TODO: distinguish A/B datatype + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_AK0_M_AK1, + BGridDesc_BK0_N_BK1, + B1GridDesc_BK0_N_BK1, + CGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + AK1, + BK1, + B1K1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + Gemm1NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + true, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + true, + BBlockLdsExtraN, + 
B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + false, + B1BlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + LoopSched, + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle>; + + using Block2CTileMap = OffsettedBlockToCTileMap; + + struct GroupKernelArg + { + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_; + typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // batch & stride + index_t num_blocks_per_batch_; + ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // block-to-c-tile map + Block2CTileMap block_2_ctile_map_; + + index_t block_start_, block_end_; + }; + + struct GroupDeviceArg + { + // lengths for the last dimensions of overall problem for sanity check of vector load/store + std::vector raw_lengths_mz_nz_kz_gemm1nz_; + + // strides for the last dimensions of each tensor for sanity check of vector load/store + std::vector a_mz_kz_strides_; + std::vector b_nz_kz_strides_; + std::vector b1_nz_kz_strides_; + std::vector c_mz_gemm1nz_strides_; + + // for gridwise gemm check + CGridDesc_M_N c_grid_desc_m_n_; + }; + + // Argument + // FIXME: constness + struct Argument : public BaseArgument + { + Argument(std::vector p_a_vec, + std::vector p_b_vec, + std::vector p_b1_vec, + std::vector p_c_vec, + std::vector> p_acc0_biases_vec, + std::vector> p_acc1_biases_vec, + std::vector problem_desc_vec, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op} + { + // TODO ANT: implement bias addition + group_count_ = problem_desc_vec.size(); + + if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() && + group_count_ == p_b1_vec.size() && group_count_ == p_c_vec.size())) + { + throw std::runtime_error("wrong! group_count_ != a/b/b1/c_vec.size"); + } + + if(!(p_acc0_biases_vec.size() == p_acc1_biases_vec.size())) + { + throw std::runtime_error("wrong! 
acc0_bias_vec.size != acc1_bias_vec.size"); + } + + grid_size_ = 0; + + for(std::size_t i = 0; i < group_count_; i++) + { + const auto p_a_grid = static_cast(p_a_vec[i]); + const auto p_b_grid = static_cast(p_b_vec[i]); + const auto p_b1_grid = static_cast(p_b1_vec[i]); + const auto p_c_grid = static_cast(p_c_vec[i]); + + const auto& problem_desc = problem_desc_vec[i]; + + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides); + const auto b1_grid_desc_bk0_n_bk1 = MakeB1GridDescriptor_BK0_N_BK1( + problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N( + problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides); + + const auto a_grid_desc_g_m_k = Transform::MakeAGridDescriptor_G_M_K( + problem_desc.a_gs_ms_ks_lengths, problem_desc.a_gs_ms_ks_strides); + const auto b_grid_desc_g_n_k = Transform::MakeB0GridDescriptor_G_N_K( + problem_desc.b0_gs_ns_ks_lengths, problem_desc.b0_gs_ns_ks_strides); + const auto b1_grid_desc_g_n_k = Transform::MakeB1GridDescriptor_G_N_K( + problem_desc.b1_gs_os_ns_lengths, problem_desc.b1_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = Transform::MakeCGridDescriptor_G_M_N( + problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n); + + const index_t BlockStart = grid_size_; + const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart); + const index_t batch_count = c_grid_desc_g_m_n.GetLength(I0); + const index_t grid_size_grp = + block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n) * batch_count; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + // batch stride + const auto compute_base_ptr_of_batch = ComputeBasePtrOfStridedBatch( + a_grid_desc_g_m_k, b_grid_desc_g_n_k, b1_grid_desc_g_n_k, c_grid_desc_g_m_n); + + // C0 mask + const auto c0_matrix_mask = C0MatrixMask(b_grid_desc_g_n_k.GetLength(I1)); + + grid_size_ += grid_size_grp; + + // for each group, make sure acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias and + // so on + if(!(problem_desc.acc0_biases_gs_ms_ns_lengths.size() == NumAcc0Bias && + problem_desc.acc0_biases_gs_ms_ns_strides.size() == NumAcc0Bias && + problem_desc.acc1_biases_gs_ms_os_lengths.size() == NumAcc1Bias && + problem_desc.acc1_biases_gs_ms_os_strides.size() == NumAcc1Bias)) + { + throw std::runtime_error( + "wrong! 
number of biases in function argument does not " + "match that in template argument"); + } + + group_kernel_args_.push_back({p_a_grid, + p_b_grid, + p_b1_grid, + p_c_grid, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b1_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n), + compute_base_ptr_of_batch, + c0_matrix_mask, + block_2_ctile_map, + BlockStart, + BlockEnd}); + + group_device_args_.push_back( + {{problem_desc.a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN - 1], + problem_desc.b0_gs_ns_ks_lengths[NumDimG + NumDimN + NumDimK - 1], + problem_desc.b1_gs_os_ns_lengths[NumDimG + NumDimO - 1]}, + {problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + problem_desc.a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + {problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN - 1], + problem_desc.b0_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]}, + {problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO - 1], + problem_desc.b1_gs_os_ns_strides[NumDimG + NumDimO + NumDimN - 1]}, + {problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM - 1], + problem_desc.c_gs_ms_os_strides[NumDimG + NumDimM + NumDimO - 1]}, + c_grid_desc_m_n}); + } + } + + std::vector group_kernel_args_; + std::vector group_device_args_; + + std::size_t group_count_; + index_t grid_size_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!DeviceOp::IsSupportedArgument(arg)) + { + throw std::runtime_error("wrong! unsupported argument"); + } + + bool all_has_main_k_block_loop = true; + bool some_has_main_k_block_loop = false; + for(std::size_t i = 0; i < arg.group_count_; i++) + { + const auto K = arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) * + arg.group_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2); + const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K); + all_has_main_k_block_loop &= y; + some_has_main_k_block_loop |= y; + } + + hipGetErrorString( + hipMemcpyWithStream(arg.p_workspace_, + arg.group_kernel_args_.data(), + arg.group_kernel_args_.size() * sizeof(GroupKernelArg), + hipMemcpyHostToDevice, + stream_config.stream_id_)); + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = + kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.group_count_, + arg.a_element_op_, + arg.b_element_op_, + arg.acc_element_op_, + arg.b1_element_op_, + arg.c_element_op_); + }; + + // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need + // to concern Gemm0's loop + if(all_has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant{}); + } + else if(!some_has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + throw std::runtime_error("wrong! 
all gemm problems have to simultaneously meet " + "has_main_k_block_loop or no_main_k_block_loop"); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + // TODO ANT: Check if tensor specialization & strides mismatch + + bool all_has_main_k_block_loop = true; + bool some_has_main_k_block_loop = false; + + for(std::size_t i = 0; i < arg.group_count_; i++) + { + const auto& kernel_arg = arg.group_kernel_args_[i]; + const auto& device_arg = arg.group_device_args_[i]; + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_m = device_arg.c_grid_desc_m_n_.GetLength(I0); + const index_t c_gemm1n = device_arg.c_grid_desc_m_n_.GetLength(I1); + const index_t a_m = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I1); + const index_t b1_gemm1n = kernel_arg.b1_grid_desc_bk0_n_bk1_.GetLength(I1); + if(!(c_m == a_m && c_gemm1n == b1_gemm1n)) + { + return false; + } + + // Check if having main loop + const auto K = kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * + kernel_arg.a_grid_desc_ak0_m_ak1_.GetLength(I2); + const bool y = GridwiseGemm::CalculateHasMainKBlockLoop(K); + all_has_main_k_block_loop &= y; + some_has_main_k_block_loop |= y; + + // Note: we need raw lengths since threadwise copy can not handle vector load when + // part of vector is out of bounds + const auto MzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[0]; + const auto NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[1]; + const auto KzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[2]; + const auto Gemm1NzRaw = device_arg.raw_lengths_mz_nz_kz_gemm1nz_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b_extent_lowest = BBlockTransferSrcVectorDim == 2 ? KzRaw : NzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? NzRaw : Gemm1NzRaw; + const auto c_extent_lowest = Gemm1NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b_extent_lowest % BBlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = ABlockTransferSrcVectorDim == 2 + ? device_arg.a_mz_kz_strides_[1] + : device_arg.a_mz_kz_strides_[0]; + const auto b_stride_lowest = BBlockTransferSrcVectorDim == 2 + ? device_arg.b_nz_kz_strides_[1] + : device_arg.b_nz_kz_strides_[0]; + const auto b1_stride_lowest = B1BlockTransferSrcVectorDim == 2 + ? 
device_arg.b1_nz_kz_strides_[1] + : device_arg.b1_nz_kz_strides_[0]; + const auto c_stride_lowest = + device_arg.c_mz_gemm1nz_strides_[1]; // cshuffle assumes lowest dim in Gemm1Ns to be + // contiguous + + if(!(a_stride_lowest == 1 || b_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + return false; + } + + if(!GridwiseGemm::CheckValidity(kernel_arg.a_grid_desc_ak0_m_ak1_, + kernel_arg.b_grid_desc_bk0_n_bk1_, + kernel_arg.b1_grid_desc_bk0_n_bk1_, + device_arg.c_grid_desc_m_n_, + kernel_arg.block_2_ctile_map_)) + { + return false; + } + } + + // all gemm problems have to simultaneously meet has_main_k_block_loop or + // no_main_k_block_loop + if(!(all_has_main_k_block_loop || !some_has_main_k_block_loop)) + { + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector p_a_vec, + std::vector p_b_vec, + std::vector p_b1_vec, + std::vector p_c_vec, + std::vector> p_acc0_biases_vec, + std::vector> p_acc1_biases_vec, + std::vector problem_desc_vec, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a_vec, + p_b_vec, + p_b1_vec, + p_c_vec, + p_acc0_biases_vec, + p_acc1_biases_vec, + problem_desc_vec, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector p_a_vec, + std::vector p_b_vec, + std::vector p_b1_vec, + std::vector p_c_vec, + std::vector> p_acc0_biases_vec, + std::vector> p_acc1_biases_vec, + std::vector problem_desc_vec, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(p_a_vec, + p_b_vec, + p_b1_vec, + p_c_vec, + p_acc0_biases_vec, + p_acc1_biases_vec, + problem_desc_vec, + a_element_op, + b_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << Gemm1NPerBlock << ", " + << Gemm1KPerBlock << ", " + << B1K1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(BSpec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << ", " + << getMaskingSpecializationString(MaskingSpec) << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) 
const override + { + return dynamic_cast(p_arg)->group_count_ * sizeof(GroupKernelArg); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp new file mode 100755 index 000000000000..cbee4e09f407 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -0,0 +1,795 @@ +#pragma once +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].BlockStart_ && + block_id < gemm_desc_ptr[group_id].BlockEnd_)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::template Run( + gemm_desc_ptr[group_id].a_ptr_, + gemm_desc_ptr[group_id].b_ptr_, + gemm_desc_ptr[group_id].ds_ptr_, + gemm_desc_ptr[group_id].e_ptr_, + p_shared, + a_element_op, + b_element_op, + c_element_op, + gemm_desc_ptr[group_id].a_grid_desc_ak0_m_ak1_, + gemm_desc_ptr[group_id].b_grid_desc_bk0_n_bk1_, + gemm_desc_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_desc_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + gemm_desc_ptr[group_id].block_2_etile_map_); +#else + ignore = gemm_descs_const; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif +} + +template +struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm +{ + using DeviceOp = DeviceGroupedGemm_Xdl; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + 
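// Illustrative sketch, not part of this patch: kernel_grouped_gemm_xdl above maps each
// workgroup to its GEMM group by bisecting over the per-group [BlockStart_, BlockEnd_)
// ranges carried in the device-side kernel arguments. The same lookup, expressed as a
// standalone host-side helper (GroupRange and FindGroupForBlock are hypothetical names,
// assuming contiguous ranges that together cover every valid block id), looks like this:

#include <vector>

struct GroupRange
{
    int block_start; // first block id owned by the group (inclusive)
    int block_end;   // one past the last block id owned by the group (exclusive)
};

inline int FindGroupForBlock(const std::vector<GroupRange>& groups, int block_id)
{
    int left     = 0;
    int right    = static_cast<int>(groups.size());
    int group_id = (left + right) / 2;

    // Narrow [left, right) until the range of group_id contains block_id,
    // mirroring the binary search performed inside the kernel.
    while(!(block_id >= groups[group_id].block_start && block_id < groups[group_id].block_end))
    {
        if(block_id < groups[group_id].block_start)
        {
            right = group_id;
        }
        else
        {
            left = group_id;
        }
        group_id = (left + right) / 2;
    }

    return group_id;
}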
+ static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + static auto MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return DeviceOp::MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + using ComputeDataType = ADataType; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + BDataType, + ComputeDataType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + NumPrefetch, // NumGemmKPrefetchStage + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t< + decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + DsGridDesc_M_N{}))>; + using 
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + struct GroupedGemmBlock2ETileMap + { + using Block2ETileMap = + remove_cvref_t; + + GroupedGemmBlock2ETileMap() + { + block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}); + BlockStart_ = -1; + } + + GroupedGemmBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, ck::index_t BlockStart) + { + block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + BlockStart_ = BlockStart; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return block_2_etile_map_.CalculateBottomIndex( + make_multi_index(idx_top[I0] - BlockStart_)); + } + + // it's actually E-Tile + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + __host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const + { + return block_2_etile_map_.CheckValidity(e_grid_desc_m_n); + } + + Block2ETileMap block_2_etile_map_; + ck::index_t BlockStart_; + }; + + struct GemmBiasTransKernelArg + { + // pointers + const ADataType* a_ptr_; + const BDataType* b_ptr_; + typename GridwiseGemm::DsGridPointer ds_ptr_; + EDataType* e_ptr_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + GroupedGemmBlock2ETileMap block_2_etile_map_; + ck::index_t BlockStart_, BlockEnd_; + }; + + // Argument + struct Argument : public BaseArgument + { + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} + { + grid_size_ = 0; + + group_count_ = ck::type_convert(gemm_descs.size()); + + if(!(group_count_ == ck::type_convert(p_As.size()) && + group_count_ == ck::type_convert(p_Bs.size()) && + group_count_ == ck::type_convert(p_Es.size()))) + { + throw std::runtime_error("wrong! 
group_count_ != p_As/b/c.size"); + } + + gemm_desc_kernel_arg_.reserve(group_count_); + + skipped_group_count_ = 0; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + const index_t M = gemm_descs[i].M_; + const index_t N = gemm_descs[i].N_; + const index_t K = gemm_descs[i].K_; + + a_mtx_mraw_kraw_.emplace_back(M, K); + b_mtx_nraw_kraw_.emplace_back(N, K); + + if(M == 0) + { + skipped_group_count_++; + continue; + } + + const index_t StrideA = gemm_descs[i].stride_A_; + const index_t StrideB = gemm_descs[i].stride_B_; + const index_t StrideC = gemm_descs[i].stride_C_; + + // pointer + typename GridwiseGemm::DsGridPointer p_ds_grid{}; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DDataType = remove_cvref_t>; + + p_ds_grid(j) = static_cast(p_Ds[i][j]); + }); + + // tensor descriptors for problem definiton + const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K(M, K, StrideA); + const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K(K, N, StrideB); + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N( + M, N, gemm_descs[i].stride_Ds_[j]); + }); + + const auto e_grid_desc_m_n = + DeviceOp::MakeEGridDescriptor_M_N(M, N, StrideC); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_ak0_m_ak1 = + GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); + + const auto b_grid_desc_bk0_n_bk1 = + GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); + + const index_t grid_size_grp = + GroupedGemmBlock2ETileMap(e_grid_desc_m_n, 0) + .block_2_etile_map_.CalculateGridSize(e_grid_desc_m_n); + + const index_t BlockStart = grid_size_; + const index_t BlockEnd = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(e_grid_desc_m_n, BlockStart); + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map)) + { + // tensor descriptors for block/thread-wise copy + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n); + + gemm_desc_kernel_arg_.push_back( + GemmBiasTransKernelArg{static_cast(p_As[i]), + static_cast(p_Bs[i]), + p_ds_grid, + static_cast(p_Es[i]), + a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map, + BlockStart, + BlockEnd}); + } + } + } + + // private: + index_t group_count_; + index_t skipped_group_count_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation c_element_op_; + + std::vector gemm_desc_kernel_arg_; + std::vector> a_mtx_mraw_kraw_; + std::vector> b_mtx_nraw_kraw_; + + index_t grid_size_; + void* gemm_kernel_host_args_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, + const StreamConfig& 
stream_config = StreamConfig{}, + hipStream_t cpy_stream = nullptr, + hipEvent_t cpy_event = nullptr) + { + bool has_main_k_block_loop = true; + + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "group: " << i << " arg.a_grid_desc_ak0_m_ak1_{" + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) + << ", " + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I1) + << ", " + << arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2) + << "}"; + + std::cout << ", arg.b_grid_desc_bk0_n_bk1_{" + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I0) + << ", " + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I1) + << ", " + << arg.gemm_desc_kernel_arg_[i].b_grid_desc_bk0_n_bk1_.GetLength(I2) + << "}"; + + std::cout << ", arg.e_grid_desc_m_n_{ " + << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I0) << ", " + << arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_.GetLength(I1) << "}" + << std::endl; + } + + if(!GridwiseGemm::CheckValidity(arg.gemm_desc_kernel_arg_[i].a_grid_desc_m_k_, + arg.gemm_desc_kernel_arg_[i].b_grid_desc_n_k_, + arg.gemm_desc_kernel_arg_[i].ds_grid_desc_m_n_, + arg.gemm_desc_kernel_arg_[i].e_grid_desc_m_n_, + arg.gemm_desc_kernel_arg_[i].block_2_etile_map_)) + { + throw std::runtime_error( + "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + const auto K = arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) * + arg.gemm_desc_kernel_arg_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop) + { + throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); + } + } + + // If the user provides copy stream and copy event, we assume that they're also + // responsible for providing allocated host memory (eg. pinned) which + // would be used to copy kernel arguments to the device. + if(cpy_stream && cpy_event) + { + if(arg.gemm_kernel_host_args_ == nullptr) + { + std::ostringstream err; + err << "No memory has been allocated for gemm kernel host args " + << "when providing the copy stream and copy event! In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + hipGetErrorString(hipMemcpyAsync(arg.p_workspace_, + arg.gemm_kernel_host_args_, + arg.group_count_ * sizeof(GemmBiasTransKernelArg), + hipMemcpyHostToDevice, + cpy_stream)); + hipGetErrorString(hipEventRecord(cpy_event, cpy_stream)); + hipGetErrorString(hipEventSynchronize(cpy_event)); + } + else // In this case CK owns memory allocated on host. 
+ { + hipGetErrorString(hipMemcpyAsync(arg.p_workspace_, + arg.gemm_desc_kernel_arg_.data(), + arg.gemm_desc_kernel_arg_.size() * + sizeof(GemmBiasTransKernelArg), + hipMemcpyHostToDevice, + stream_config.stream_id_)); + } + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_) { + const auto kernel = kernel_grouped_gemm_xdl; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_desc_kernel_arg_.size(), + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + }; + + if(has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}); + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((ck::type_convert(arg.gemm_desc_kernel_arg_.size()) + + arg.skipped_group_count_) != arg.group_count_) + { + return false; + } + + bool supported = true; + + // If we use padding we do not support vector loads for dimensions not divisible by vector + // load size. + if constexpr(GemmSpec != GemmSpecialization::Default) + { + // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout, + // thus we have to adapt it to the {M,K} or {N,K} layout. + const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0; + const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0; + + for(index_t i = 0; i < arg.group_count_; ++i) + { + const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); + const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); + + supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); + supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); + } + } + + return supported; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + { + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) override + { + return std::make_unique( + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemm_Xdl" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << 
", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + return p_arg_->group_count_ * sizeof(GemmBiasTransKernelArg); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDXdlCShuffle::Argument structure!"); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + return GetWorkSpaceSize(p_arg); + } + + void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args); + } + + size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); } + + //---------------------------------------------------------------------------------------------- + /// @brief Sets the host kernel arguments pointer and copies that data on the host side. + /// This function can be utilised to use pinned memory for the host args and + /// achieve fully async data copy. + /// + /// @param p_arg The pointer to the Argument we're going to update. + /// @param[in] p_host_kernel_args The pointer to the host memory where the kernel + /// arguments will be copied + /// + void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const + { + Argument* pArg_ = dynamic_cast(p_arg); + if(!pArg_) + { + throw std::runtime_error("Failed to cast argument pointer!"); + } + + pArg_->gemm_kernel_host_args_ = p_host_kernel_args; + std::copy(pArg_->gemm_desc_kernel_arg_.begin(), + pArg_->gemm_desc_kernel_arg_.end(), + static_cast(pArg_->gemm_kernel_host_args_)); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp new file mode 100755 index 000000000000..8fe71fb9a2d6 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp @@ -0,0 +1,987 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + uint32_t* barrier_count, + const index_t barrier_size_grp, + const index_t group_count, + const index_t grid_size_grp, + const index_t KBatch, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = block_id / grid_size_grp; + + if(group_id >= group_count) + return; + + const index_t M = gemm_desc_ptr[group_id].M; + const index_t N = gemm_desc_ptr[group_id].N; + const index_t K = gemm_desc_ptr[group_id].K; + + if(M == 0 || N == 0 || K == 0) + return; + + const auto StrideA = gemm_desc_ptr[group_id].StrideA; + const auto StrideB = gemm_desc_ptr[group_id].StrideB; + const auto StrideDs = gemm_desc_ptr[group_id].StrideDs; + const auto StrideE = gemm_desc_ptr[group_id].StrideE; + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); + + const index_t BlockStart = group_id * grid_size_grp; + + const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; + + const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); + + constexpr auto NumDTensor = DsDataType::Size(); + + using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); + + DsGridPointer p_ds_grid_; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + // D pointer + p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + index_t id_off = 0; + index_t id_local = get_block_1d_id() - BlockStart; + + const index_t mn_blocks = local_grid_size / KBatch; + + while(id_local < local_grid_size) + { + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); + + if constexpr(Zeroing) + { + auto barrier_count_finished = + barrier_count + group_id * barrier_size_grp + id_local % mn_blocks; + GridwiseGemm::template RunWithZeroing(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + barrier_count_finished, + a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + block_2_etile_map); + } + else + { + + GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + nullptr, + 
a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + block_2_etile_map); + } + + id_off += grid_size_grp; + id_local += grid_size_grp; + } +#else + ignore = gemm_descs_const; + ignore = barrier_count; + ignore = barrier_size_grp; + ignore = group_count; + ignore = grid_size_grp; + ignore = KBatch; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif +} + +template +struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK +{ + using DeviceOp = DeviceGroupedGemm_Xdl_Fixed_NK; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + using AComputeType = ComputeType; + using BComputeType = ComputeType; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle< + ADataType, // TODO: distinguish A/B datatype + BDataType, + AComputeType, + BComputeType, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + NumPrefetch, // NumGemmKPrefetchStage + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched, + PipelineVer, + ALDSType, + BLDSType>; + + template + struct OffsettedBlockToCTileMapMLoops + { + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMapMLoops( + UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) + { + block_to_ctile_map_ = block_to_ctile_map; + block_start_ = block_start; + id_off_ = id_off; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); + + return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + template + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); + } + + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t block_start_; + index_t id_off_; + }; + + template + struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops + { + static 
constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, + index_t N, + index_t KBatch, + index_t M01 = 8) + : M_(M), N_(N), KBatch_(KBatch), M01_(M01) + { + } + + template + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) + : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) + { + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0 * KBatch_; + } + + template + __host__ __device__ constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); + + block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t KBatch_; + index_t M01_; + }; + + using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + + // TODO: replace with GroupedGemmKernelArgument + struct GemmBiasTransKernelArg + { + // pointers + const void* a_ptr_; + const void* b_ptr_; + std::array ds_ptr_; + void* e_ptr_; + + index_t M_, N_, K_; + index_t StrideA_, StrideB_; + std::array StrideDs_; + index_t StrideE_; + }; + + // Argument + struct Argument : public BaseArgument + { + + void UpdateKBatch(index_t k_batch) + { + k_batch_ = k_batch; + + if(k_batch_ < 1) + { + + throw std::runtime_error("wrong! 
k_batch must be > 0"); + } + + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + + const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE_; + const index_t N = gemm_desc_kernel_arg_[0].N_; + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + AverM, N, StrideE); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + grid_size_ = grid_size_grp_ * group_count_; + } + + Argument(std::vector&, + std::vector&, + std::vector>&, + std::vector&, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} + { + grid_size_ = 0; + + k_batch_ = 1; + + grouped_gemm_kernel_args_dev = nullptr; + + group_count_ = ck::type_convert(gemm_descs.size()); + + gemm_desc_kernel_arg_.reserve(group_count_); + + index_t group_id = 0; + + sum_of_m = gemm_descs[0].M_; + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + const index_t N = gemm_descs[0].N_; + const index_t K = gemm_descs[0].K_; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_) + { + throw std::runtime_error("wrong! M/N/K is not identical"); + } + + a_mtx_mraw_kraw_.emplace_back(sum_of_m, K); + b_mtx_nraw_kraw_.emplace_back(N, K); + + const index_t StrideA = gemm_descs[i].stride_A_; + const index_t StrideB = gemm_descs[i].stride_B_; + const index_t StrideE = gemm_descs[i].stride_C_; + + // pointer + std::array p_ds_grid; + + static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; }); + + std::array StrideDs; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + // using DLayout = remove_cvref_t>; + + if(gemm_descs[i].stride_Ds_.size() != NumDTensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + StrideDs[j] = gemm_descs[i].stride_Ds_[j]; + }); + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + AverM, N, StrideE); + + // block-to-e-tile map + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + if(group_id * grid_size_grp_ != grid_size_) + { + throw std::runtime_error("wrong! grid_size_grp_ is not identical!"); + } + + grid_size_ += grid_size_grp_; + + // check block-to-E-tile + if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) + { + throw std::runtime_error("wrong! block_2_etile_map validation failed"); + } + + if(!GridwiseGemm:: + template CheckValidity( + AverM, N, K, StrideA, StrideB, StrideDs, StrideE, 1)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); + } + + gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{ + nullptr, + nullptr, + p_ds_grid, + nullptr, + AverM, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + }); + + group_id++; + } + + const auto e_grid_desc_sum_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1}; + + barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); + } + + // private: + index_t group_count_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation c_element_op_; + + std::vector gemm_desc_kernel_arg_; + std::vector> a_mtx_mraw_kraw_; + std::vector> b_mtx_nraw_kraw_; + + const void* grouped_gemm_kernel_args_dev; + + index_t grid_size_; + index_t grid_size_grp_; + index_t barrier_size_grp_; + index_t sum_of_m; + + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + bool has_main_k_block_loop = true; + + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { + const auto KPad = + GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, arg.k_batch_); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop) + { + throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); + } + } + + if(arg.grouped_gemm_kernel_args_dev == nullptr) + { + throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr"); + } + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) { + if(arg.k_batch_ == 1) + { + const auto kernel = + kernel_grouped_gemm_xdl_fixed_nk, + GemmSpec, + false, + ALayout, + BLayout, + DsLayout, + ELayout, + DsDataType, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + e_global_memory_operation_, + has_main_k_block_loop_>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + nullptr, + arg.barrier_size_grp_, + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.k_batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + } + else + { + const auto kernel = + kernel_grouped_gemm_xdl_fixed_nk, + GemmSpec, + true, + ALayout, + BLayout, + DsLayout, + ELayout, + DsDataType, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + e_global_memory_operation_, + has_main_k_block_loop_>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + reinterpret_cast(arg.p_workspace_), + arg.barrier_size_grp_, + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.k_batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + } + }; + + constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd; + constexpr auto Set = InMemoryDataOperationEnum::Set; + + // For bf16 datatype only kbatch = 1 scenario is supported. 
This condition is + // enforced in IsSupportedArgument function + if constexpr(std::is_same::value) + { + if(has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}, + integral_constant{}); + } + } + else + { + if(arg.k_batch_ > 1) + { + if(has_main_k_block_loop) + { + ave_time = launch_kernel( + integral_constant{}, + integral_constant{}); + } + else + { + ave_time = launch_kernel( + integral_constant{}, + integral_constant{}); + } + } + else + { + if(has_main_k_block_loop) + { + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) + { + return false; + } + + bool supported = true; + + // If we use padding we do not support vector loads for dimensions not divisible by + // vector load size. + if constexpr(GemmSpec != GemmSpecialization::Default) + { + // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} + // layout, thus we have to adapt it to the {M,K} or {N,K} layout. + const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0; + const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0; + + for(index_t i = 0; i < arg.group_count_; ++i) + { + const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); + const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); + + supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); + supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); + } + } + + // For bf16 datatype only kbatch = 1 is supported since there is no AtomicAdd + // instruction that supports bf16 and we cannot use splitk because of that + if constexpr(std::is_same::value) + { + supported = supported & (arg.k_batch_ == 1); + } + + return supported; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + { + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) override + { + return std::make_unique( + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemm_Xdl_Fixed_NK" + << "<" + << BlockSize << ", " + << 
MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } + + // polymorphic + void SetDeviceKernelArgs(BaseArgument* p_arg, void* kernel_args) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->grouped_gemm_kernel_args_dev = kernel_args; + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + return arg_ptr->group_count_ * arg_ptr->barrier_size_grp_ * sizeof(uint32_t); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + return arg_ptr->group_count_ * sizeof(GroupedGemmKernelArgument); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& stream_config = StreamConfig{}) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->p_workspace_ = p_workspace; + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"); + + hip_check_error( + hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(arg_ptr), stream_config.stream_id_)); + } + + static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } + + // polymorphic + void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->UpdateKBatch(k_batch); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"); + } + + void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->UpdateKBatch(kbatch); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp new file mode 100755 index 000000000000..01f52881f4d5 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp @@ -0,0 +1,735 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
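+//
+// Host-side usage sketch (illustrative only; `op_instance`, the pointer vectors,
+// `gemm_descs`, `k_batch` and `p_dev_kernel_args` below are placeholder names,
+// not part of this header):
+//
+//   DeviceGroupedGemmXdlSplitKCShuffle</* ... */> op_instance;
+//   auto p_arg = op_instance.MakeArgumentPointer(p_As, p_Bs, p_Ds, p_Es, gemm_descs,
+//                                                PassThrough{}, PassThrough{}, PassThrough{});
+//   op_instance.SetKBatchSize(p_arg.get(), k_batch);
+//   // p_dev_kernel_args must hold at least GetDeviceKernelArgSize(p_arg.get()) bytes;
+//   // SetDeviceKernelArgs() forwards it to SetWorkSpacePointer().
+//   op_instance.SetDeviceKernelArgs(p_arg.get(), p_dev_kernel_args);
+//   auto invoker = op_instance.MakeInvoker();
+//   invoker.Run(p_arg.get(), StreamConfig{});
+//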
+ +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/env.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + __shared__ uint8_t p_shared[shared_size]; + + const index_t block_id = get_block_1d_id(); + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].block_start_ && + block_id < gemm_desc_ptr[group_id].block_end_)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].block_start_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + GridwiseGemm::template Run( + gemm_desc_ptr[group_id].karg_, + static_cast(p_shared), + gemm_desc_ptr[group_id].block_2_ctile_map_, + a_element_op, + b_element_op, + c_element_op); +#else + ignore = gemm_descs_const; + ignore = group_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif // end of if (defined(__gfx9__)) +} + +template > && + is_same_v>, + bool> = false> +struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static_assert(KPerBlock % AK1 == 0); + static constexpr index_t K0PerBlock = KPerBlock / AK1; + + using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< + BlockSize, + ADataType, + BDataType, + AccDataType, + EDataType, + ALayout, + BLayout, + ELayout, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + NumGemmKPrefetchStage, + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + AK1, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + false, // AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 
BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + false, // BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferScalarPerVector_NPerBlock, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + LoopSched, + PipelineVer>; + + using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N; + using Block2ETileMapKSplit = + BlockToCTileMap_KSplit_M00_N0_M01Adapt; + // Block2CTileMap configuration parameter. + static constexpr index_t B2E_M01 = 8; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; + using KernelArgument = typename GridwiseGemm::Argument; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + struct GemmTransKernelArg + { + KernelArgument karg_; + GroupedGemmBlock2ETileMap block_2_ctile_map_; + index_t block_start_, block_end_; + + GemmTransKernelArg() = default; + GemmTransKernelArg(KernelArgument&& karg, + GroupedGemmBlock2ETileMap&& b2c_map, + index_t block_start, + index_t block_end) + : karg_{karg}, + block_2_ctile_map_{b2c_map}, + block_start_{block_start}, + block_end_{block_end} + { + } + }; + + static constexpr index_t DefaultKBatch = 1; + + // Argument + struct Argument : public BaseArgument + { + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector& p_Es, + std::vector& gemm_descs) + : Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch) + { + // TODO: use occupancy api to calculate appropriate batch size. + } + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector& p_Es, + std::vector& gemm_descs, + index_t kbatch) + : K_BATCH{kbatch}, gemm_kernel_host_args_{nullptr} + { + grid_size_ = 0; + group_count_ = ck::type_convert(gemm_descs.size()); + + if(!(group_count_ == ck::type_convert(p_As.size()) && + group_count_ == ck::type_convert(p_Bs.size()) && + group_count_ == ck::type_convert(p_Es.size()))) + { + throw std::runtime_error("wrong! 
group_count_ != p_As/b/c.size"); + } + + gemm_kernel_args_.reserve(group_count_); + + skipped_group_count_ = 0; + + for(std::size_t i = 0; i < gemm_descs.size(); ++i) + { + const index_t M = gemm_descs[i].M_; + const index_t N = gemm_descs[i].N_; + const index_t K = gemm_descs[i].K_; + + if(M == 0) + { + skipped_group_count_++; + continue; + } + + const index_t stride_a = gemm_descs[i].stride_A_; + const index_t stride_b = gemm_descs[i].stride_B_; + const index_t stride_c = gemm_descs[i].stride_C_; + + const index_t m_padded = GridwiseGemm::CalculateMPadded(M); + const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH); + const index_t k0_padded = GridwiseGemm::CalculateK0Padded(K, K_BATCH); + + const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, stride_c); + + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + auto karg = KernelArgument{type_convert(p_As[i]), + type_convert(p_Bs[i]), + type_convert(p_Es[i]), + M, + N, + K, + stride_a, + stride_b, + stride_c, + m_padded, + n_padded, + k_padded, + k0_padded, + K_BATCH}; + + gemm_kernel_args_.emplace_back( + std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); + } + } + + /** + * @brief Recalculate group grid size for all gemms and update B2C maps. + * + * @param[in] kbatch The new splitK parameter value. 
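+         *
+         * @note       A new kbatch changes KPadded/K0Padded and the per-group grid
+         *             size, so each group's block_start_/block_end_ range and the
+         *             total grid_size_ are recomputed here before the next launch.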
+ */ + void UpdateKBatch(index_t kbatch) + { + K_BATCH = kbatch; + grid_size_ = 0; + + for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i) + { + + auto& karg = gemm_kernel_args_[i].karg_; + + const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH); + const index_t k0_padded = GridwiseGemm::CalculateK0Padded(karg.K, K_BATCH); + + const auto c_grid_desc_m_n = + GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); + + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + karg.KPadded = k_padded; + karg.K0Padded = k0_padded; + karg.k_batch = K_BATCH; + gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map; + gemm_kernel_args_[i].block_start_ = block_start; + gemm_kernel_args_[i].block_end_ = block_end; + } + } + + // private: + index_t K_BATCH; + index_t group_count_; + index_t skipped_group_count_; + + std::vector gemm_kernel_args_; + void* gemm_kernel_host_args_; + index_t grid_size_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, + const StreamConfig& stream_config = StreamConfig{}, + hipStream_t cpy_stream = nullptr, + hipEvent_t cpy_event = nullptr) + { + index_t K0 = arg.gemm_kernel_args_[0].karg_.K0Padded; + bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1; + bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& karg = arg.gemm_kernel_args_[i].karg_; + if(stream_config.log_level_ > 0) + { + karg.Print(); + } + + auto kbatch = karg.k_batch; + + if(!GridwiseGemm::CheckValidity(karg)) + { + std::ostringstream err; + err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + K0 = karg.K0Padded; + bool not_all_have_main_k0_block_loop_same = + all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1); + + if(not_all_have_main_k0_block_loop_same) + { + std::ostringstream err; + err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(not_all_have_kbatch_value_same) + { + std::ostringstream err; + err << "Not all gemms have same kbatch value (=1 or >1)! " + << "group [" << i << "], kbatch: " << kbatch + << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch + << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + } + + // If the user provides copy stream and copy event, we assume that they're also + // responsible for providing allocated host memory (eg. pinned) which + // would be used to copy kernel arguments to the device. + if(cpy_stream && cpy_event) + { + if(arg.gemm_kernel_host_args_ == nullptr) + { + std::ostringstream err; + err << "No memory has been allocated for gemm kernel host args " + << "when providing the copy stream and copy event! 
In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + hip_check_error(hipMemcpyAsync(arg.p_workspace_, + arg.gemm_kernel_host_args_, + arg.group_count_ * sizeof(GemmTransKernelArg), + hipMemcpyHostToDevice, + cpy_stream)); + hip_check_error(hipEventRecord(cpy_event, cpy_stream)); + hip_check_error(hipEventSynchronize(cpy_event)); + } + else // In this case CK owns memory allocated on host. + { + + hip_check_error( + hipMemcpyAsync(arg.p_workspace_, + arg.gemm_kernel_args_.data(), + arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg), + hipMemcpyHostToDevice, + stream_config.stream_id_)); + } + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + if(all_have_kbatch_gt_one) + { + for(const auto& trans_arg : arg.gemm_kernel_args_) + { + const auto& karg = trans_arg.karg_; + hip_check_error(hipMemsetAsync(karg.p_c_grid, + 0, + karg.M * karg.N * sizeof(EDataType), + stream_config.stream_id_)); + } + } + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_kernel_args_.size(), + PassThrough{}, + PassThrough{}, + PassThrough{}); + }; + + if(all_have_main_k0_block_loop) + { + if(all_have_kbatch_gt_one) + { + const auto kernel = + kernel_grouped_gemm_xdl_splitk; + + Run(kernel); + } + else + { + const auto kernel = + kernel_grouped_gemm_xdl_splitk; + + Run(kernel); + } + } + else + { + if(all_have_kbatch_gt_one) + { + const auto kernel = + kernel_grouped_gemm_xdl_splitk; + + Run(kernel); + } + else + { + const auto kernel = + kernel_grouped_gemm_xdl_splitk; + + Run(kernel); + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if((ck::type_convert(arg.gemm_kernel_args_.size()) + + arg.skipped_group_count_) != arg.group_count_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The group count is not equal to sum of skipped groups " + "and kernel args size!" + << std::endl; + } + return false; + } + + if(std::is_same_v && arg.K_BATCH > 1 && !is_bf16_atomic_supported()) + { + return false; + } + + bool supported = true; + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& a = arg.gemm_kernel_args_[i].karg_; + + bool group_arg_valid = GridwiseGemm::CheckValidity(a); + if(not group_arg_valid) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[" << __func__ << "] group id: " << i + << " has invalid GridwiseGemm settings!" 
<< std::endl; + a.Print(); + } + } + supported = supported && group_arg_valid; + } + return supported; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>&, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation) + { + return Argument{p_As, p_Bs, p_Es, gemm_descs}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>&, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation) override + { + return std::make_unique(p_As, p_Bs, p_Es, gemm_descs); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemm_XdlSplitK" + << "<" + << std::string(ALayout::name)[0] << "," + << std::string(BLayout::name)[0] << "," + << std::string(ELayout::name)[0] << "," + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerXDL << ", " + << NPerXDL << ", " + << MXdlPerWave << ", " + << NXdlPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMXdlPerWavePerShuffle << ", " + << CShuffleNXdlPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + return p_arg_->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!"); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + return GetWorkSpaceSize(p_arg); + } + + size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); } + + // TODO: deperecation notice. + static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); } + + // polymorphic + void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->UpdateKBatch(kbatch); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!"); + } + + void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args); + } + + //---------------------------------------------------------------------------------------------- + /// @brief Sets the host kernel arguments pointer and copies that data on the host side. + /// This function can be utilised to use pinned memory for the host args and + /// achieve fully async data copy. + /// + /// @param p_arg The pointer to the Argument we're going to update. 
+ /// @param[in] p_host_kernel_args The pointer to the host memory where the kernel + /// arguments will be copied + /// + void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const + { + Argument* pArg_ = dynamic_cast(p_arg); + if(!pArg_) + { + throw std::runtime_error("Failed to cast argument pointer!"); + } + + pArg_->gemm_kernel_host_args_ = p_host_kernel_args; + std::copy(pArg_->gemm_kernel_args_.begin(), + pArg_->gemm_kernel_args_.end(), + static_cast(pArg_->gemm_kernel_host_args_)); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp new file mode 100755 index 000000000000..67a100a1124d --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp @@ -0,0 +1,1255 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Multi-Query Attention (MQA) kernel implementation +// Assume number of head of K,V is 1. +// Q [G0, G1, M, K] * K [G0, 1, K, N] = P [G0, G1, M, N] +// P [G0, G1, M, N] * V [G0, 1, N, O] = Out [G0, G1, M, O] +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_query_attention_wmma(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, // SequenceQ + index_t N, // SequenceK + index_t K, // HeadDim + index_t O, // SequenceK + index_t G0, // Batch + index_t G1, // HeadNum + float alpha, + bool input_permute, + bool output_permute) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + // clang-format off +// *************************************************** + const auto q_head = G1; + const auto kv_head = QueryGroupNumber; +// Make Tensor Descriptors + constexpr index_t array_size = 4; + std::array a_gs_ms_ks_lengths{G0, q_head, M, K}; + std::array a_gs_ms_ks_strides = + input_permute + ? std::array{M * q_head * K, K, q_head * K, 1} // A layout [G0, M, G1, K] + : std::array{q_head * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, kv_head, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute + ? 
std::array{N * kv_head * K, K, kv_head * K, 1} // B0 layout [G0, N, 1, K] + : std::array{kv_head * N * K, N * K, K, 1}; // B0 layout [G0, 1, N, K] + + std::array b1_gs_os_ns_lengths{G0, kv_head, O, N}; + std::array b1_gs_os_ns_strides = + input_permute + ? std::array{N * kv_head * O, O, 1, kv_head * O} // B1 layout [G0, N, 1, O] + : std::array{kv_head * N * O, N * O, 1, O}; // B1 layout [G0, 1, N, O] + + std::array c_gs_ms_os_lengths{G0, q_head, M, O}; + std::array c_gs_ms_os_strides = + output_permute + ? std::array{M * q_head * O, O, q_head * O, 1} // C layout [G0, M, G1, O] + : std::array{q_head * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + // fail to reuse DeviceOp::MakeArgument() because of the __device__ function required. + + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane(static_cast( + compute_base_ptr_of_batch.GetB0BasePtr(g_idx * QueryGroupNumber / G1))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(static_cast( + compute_base_ptr_of_batch.GetB1BasePtr(g_idx * QueryGroupNumber / G1))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseOp::template Run(p_a_grid + a_batch_offset, + p_b0_grid + b0_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + 
c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b0_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = M; + ignore = N; + ignore = K; + ignore = O; + ignore = G0; + ignore = G1; + ignore = alpha; + ignore = input_permute; + ignore = output_permute; +#endif // end of if (defined(__gfx11__)) +} + +// Computes C = A * B0 * B1 +// MN = MK * KL * LN +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceGroupedQueryAttentionForward_Wmma + : public DeviceBatchedGemmSoftmaxGemmPermute +{ + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + // TODO ANT: implement bias combination + static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); + + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimL; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimN; + static constexpr index_t NumDimGemm1K = NumDimL; + + using DeviceOp = DeviceGroupedQueryAttentionForward_Wmma; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + static constexpr auto WmmaK = 16; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + + static constexpr auto AEnableLds_auto = LWaves == 1 ? false : true; + static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true; + static constexpr auto B1EnableLds_auto = MWaves == 1 ? 
false : true; + + static constexpr auto AEnableLds_manu = false; + static constexpr auto B0EnableLds_manu = true; + static constexpr auto B1EnableLds_manu = true; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1); + static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1); + + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence, + Sequence, + GemmSpec, + ASpec, + B0Spec, + B1Spec, + CSpec>; + + __host__ __device__ static auto MakeAGridDescriptor( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + if constexpr(AEnableLds) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB0GridDescriptor( + const std::array& b0_gs_ls_ks_lengths_vec, + const std::array& b0_gs_ls_ks_strides_vec) + { + if constexpr(B0EnableLds) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB1GridDescriptor( + const std::array& b1_gs_ns_ls_lengths_vec, + const std::array& b1_gs_ns_ls_strides_vec) + { + if constexpr(B1EnableLds) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + using AGridDesc = decltype(MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); + using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + + __host__ __device__ constexpr static auto make_MaskOutPredicate() + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + return MaskDisabledPredicate{}; + } + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + return MaskOutUpperTrianglePredicate{}; + } + } + using C0MatrixMask = C0MatrixMask_impl; + + struct ComputeBasePtrOfStridedBatch + { + __host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const 
B0GridDesc_G_L_K& b0_grid_desc_g_l_k, + const B1GridDesc_G_N_L& b1_grid_desc_g_n_l, + const CGridDesc_G_M_N& c_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k), + b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l), + c_grid_desc_g_m_n_(c_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + private: + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + }; + + // GridwiseOp + using GridwiseOp = GridwiseBatchedGemmSoftmaxGemm_Wmma< + // DataType Family + ADataType, + B0DataType, + Acc0DataType, + B1DataType, + Acc1DataType, + CShuffleDataType, + CDataType, + // ElementwiseOp Family + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // InMemory Data Descriptor + AGridDesc, + B0GridDesc, + B1GridDesc, + CGridDesc_M_N, + // Tiling Family + MPerBlock, + LPerBlock, + KPerBlock, + AK1, + BK1, + NPerBlock, + LTilePerBlock, + L1, + MPerWmma, + LPerWmma, + NPerWmma, + MRepeat, + LRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + true, + AEnableLds, + ABlockLdsAddExtraM, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0BlockTransferSrcAccessOrder, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_K1, + true, + B0EnableLds, + B0BlockLdsAddExtraL, + B1BlockTransferThreadClusterLengths_L0_N_L1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_L1, + false, + B1EnableLds, + B1BlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, + NumPrefetch, + LoopSched, + PipelineVer>; + + struct RawArg : public BaseArgument + { + RawArg(const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + M_{M}, + N_{N}, + K_{K}, + O_{O}, + G0_{G0}, + G1_{G1}, + alpha_{alpha}, + input_permute_{input_permute}, + output_permute_{output_permute} + { + } + // Pointers + 
const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Raw Problem Size + index_t M_; + index_t N_; + index_t K_; + index_t O_; + index_t G0_; + index_t G1_; + float alpha_; + bool input_permute_; + bool output_permute_; + }; + + static auto MakeArgument(const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + { + return RawArg{ + p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute}; + } + + static bool IsSupportedArgument(const RawArg& arg) + { + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + if(arg.G1_ % QueryGroupNumber != 0) + { + return false; + } + + constexpr index_t array_size = 4; + ck::index_t G0 = arg.G0_; + ck::index_t G1 = arg.G1_; + ck::index_t M = arg.M_; + ck::index_t N = arg.N_; + ck::index_t K = arg.K_; + ck::index_t O = arg.O_; + bool input_permute = arg.input_permute_; + bool output_permute = arg.output_permute_; + + std::array a_gs_ms_ks_lengths{G0, G1, M, K}; + std::array a_gs_ms_ks_strides = + input_permute ? std::array{M * G1 * K, K, G1 * K, 1} + // A layout [G0, M, G1, K] + : std::array{ + G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute ? std::array{N * G1 * K, K, G1 * K, 1} + // B0 layout [G0, N, G1, K] + : std::array{ + G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::array b1_gs_os_ns_lengths{G0, G1, O, N}; + std::array b1_gs_os_ns_strides = + input_permute ? std::array{N * G1 * O, O, 1, G1 * O} + // B1 layout [G0, N, G1, O] + : std::array{ + G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::array c_gs_ms_os_lengths{G0, G1, M, O}; + std::array c_gs_ms_os_strides = + output_permute ? 
std::array{M * G1 * O, O, G1 * O, 1} + // C layout [G0, M, G1, O] + : std::array{ + G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_grid_desc = + DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + + if(!GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded + + if(!(c_g == batch_count)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = M; + const auto LzRaw = N; + const auto KzRaw = K; + const auto NzRaw = O; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + std::array a_mz_kz_strides_{ + a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}; + std::array b0_lz_kz_strides_{ + b0_gs_ns_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]}; + std::array b1_nz_lz_strides_{ + b1_gs_os_ns_strides[NumDimG + NumDimN - 1], + b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]}; + std::array c_mz_nz_strides_{ + c_gs_ms_os_strides[NumDimG + NumDimM - 1], + c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]}; + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0]; + const auto c_stride_lowest = c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + // Argument + struct Argument : public BaseArgument + { + Argument( + const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + const index_t M01, + const index_t N01, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc{ + DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc{ + DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_m_n_{ + Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + a_grid_desc_g_m_k_{ + Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc_g_l_k_{ + Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc_g_n_l_{ + Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_g_m_n_{ + Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, + a_element_op_{a_element_op}, + b0_element_op_{b0_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + c0_matrix_mask_{b0_grid_desc_g_l_k_.GetLength(I1)}, + raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1], + b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]}, + a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]}, + b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1], + b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]}, + c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1], + c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]}, + batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, + compute_ptr_offset_of_batch_{ + 
a_grid_desc_g_m_k_, b0_grid_desc_g_l_k_, b1_grid_desc_g_n_l_, c_grid_desc_g_m_n_} + { + // TODO ANT: implement bias addition + ignore = p_acc0_biases; + ignore = p_acc1_biases; + ignore = acc0_biases_gs_ms_ls_lengths; + ignore = acc0_biases_gs_ms_ls_strides; + ignore = acc1_biases_gs_ms_ns_lengths; + ignore = acc1_biases_gs_ms_ns_strides; + + if(GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n_, block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Tensor Descriptors + AGridDesc a_grid_desc; + B0GridDesc b0_grid_desc; + B1GridDesc b1_grid_desc; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + + typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // Block to Tile mapping + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_; + + // ElementwiseOp + AElementwiseOperation a_element_op_; + B0ElementwiseOperation b0_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // Strides for the last M/N/K dimensions of A/B0/B1/C + // for sanity check of vector load/store + std::array raw_lengths_mz_lz_kz_nz_; + std::array a_mz_kz_strides_; + std::array b0_lz_kz_strides_; + std::array b1_nz_lz_strides_; + std::array c_mz_nz_strides_; + + index_t batch_count_; + // Batch Offset + ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_; + }; + + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::RawArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock); + + const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0; + const auto K = arg.K_; + // printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K)); + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_grouped_query_attention_wmma; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b0_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.M_, + arg.N_, + arg.K_, + arg.O_, + arg.G0_, + arg.G1_, + arg.alpha_, + arg.input_permute_, + arg.output_permute_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } +#if 0 + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_gfx11_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 
Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.b1_grid_desc, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded + + if(!(c_g == arg.batch_count_)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0]; + const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1]; + const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2]; + const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0]; + const auto c_stride_lowest = arg.c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b0, + p_b1, + p_c, + p_acc0_biases, + p_acc1_biases, + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ls_ks_lengths, + b0_gs_ls_ks_strides, + b1_gs_ns_ls_lengths, + b1_gs_ns_ls_strides, + c_gs_ms_ns_lengths, + c_gs_ms_ns_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op}; + } +#endif + + // polymorphic + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b0_gs_ls_ks_lengths, + const std::vector& b0_gs_ls_ks_strides, + const std::vector& b1_gs_ns_ls_lengths, + const std::vector& b1_gs_ns_ls_strides, + const std::vector& c_gs_ms_ns_lengths, + const std::vector& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + std::array a_lengths; + std::array a_strides; + std::array b0_lengths; + std::array b0_strides; + std::array b1_lengths; + std::array b1_strides; + std::array c_lengths; + std::array c_strides; + std::transform(a_gs_ms_ks_lengths.begin(), + a_gs_ms_ks_lengths.end(), + a_lengths.begin(), + [](index_t i) { return i; }); + std::transform(a_gs_ms_ks_strides.begin(), + a_gs_ms_ks_strides.end(), + a_strides.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_lengths.begin(), + b0_gs_ls_ks_lengths.end(), + b0_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_strides.begin(), + 
b0_gs_ls_ks_strides.end(), + b0_strides.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_lengths.begin(), + b1_gs_ns_ls_lengths.end(), + b1_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_strides.begin(), + b1_gs_ns_ls_strides.end(), + b1_strides.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_lengths.begin(), + c_gs_ms_ns_lengths.end(), + c_lengths.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_strides.begin(), + c_gs_ms_ns_strides.end(), + c_strides.begin(), + [](index_t i) { return i; }); + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + static_cast(p_b1), + static_cast(p_c), + p_acc0_biases, + p_acc1_biases, + a_lengths, + a_strides, + b0_lengths, + b0_strides, + b1_lengths, + b1_strides, + c_lengths, + c_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceGroupedQueryAttentionForward_Wmma, " + << "QueryGroupNumber: " + << QueryGroupNumber << ", " + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << LPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << LTilePerBlock << ", " + << L1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(B0Spec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << ", " + << getMaskingSpecializationString(MaskingSpec) + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "B0EnableLds: " + << B0EnableLds << ", " + << "B1EnableLds: " + << B1EnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp new file mode 100755 index 000000000000..1ad37058db10 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp @@ -0,0 +1,408 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
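// Sketch of typical host-side usage of this image-to-column device op
// (illustrative only; the template arguments and variable names below are
// placeholders, not part of this header):
//
//   DeviceImageToColumnImpl</* layouts, data types, tile sizes ... */> op;
//   auto arg     = op.MakeArgumentPointer(p_image, p_gemm_out,
//                                         G, N, C,
//                                         input_spatial_lengths,
//                                         filter_spatial_lengths,
//                                         output_spatial_lengths,
//                                         image_g_n_c_wis_strides,
//                                         gemm_g_m_k_strides,
//                                         conv_filter_strides,
//                                         conv_filter_dilations,
//                                         input_left_pads,
//                                         input_right_pads);
//   auto invoker = op.MakeInvokerPointer();
//   if(op.IsSupportedArgument(arg.get()))
//       invoker->Run(arg.get(), StreamConfig{});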
+ +#pragma once + +#include "ck/library/utility/numeric.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/conv_tensor_rearrange_op.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/io.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Image to column: +// input : input image [G, N, Di, Hi, Wi, C] +// output : gemm form [G * N * Do * Ho * Wo, Z * Y * X * C] +// input : input image [N, Di, Hi, Wi, G, C] +// output : gemm form [N * Do * Ho * Wo * G, Z * Y * X * C] +template = 1 && NDimSpatial <= 3, bool>::type = false> +struct DeviceImageToColumnImpl + : public DeviceConvTensorRearrange +{ + static constexpr bool is_NSpatialGC = + std::is_same_v || + std::is_same_v || + std::is_same_v; + static constexpr bool is_GNSpatialC = + std::is_same_v || + std::is_same_v || + std::is_same_v; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + using ConvToGemmFwdTransformer = + TransformConvFwdToGemm; + + static constexpr auto matrix_padder = + MatrixPadder{ + MPerBlock, 0 /* NPerBlock*/, KPerBlock}; + + // Use MakeADescriptor_M_K from grouped convolution forward + static auto + MakeInputDescriptor_M_K(const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + std::array a_g_n_c_wis_lengths{1}; + std::array b_g_k_c_xs_lengths{1}; + std::array c_g_n_k_wos_lengths{1}; + + auto copy = [](const auto& x, auto& y, index_t dst_offset) { + std::copy(x.begin(), x.end(), y.begin() + dst_offset); + }; + + constexpr index_t spatial_offset = 3; + + copy(input_spatial_lengths, a_g_n_c_wis_lengths, spatial_offset); + copy(filter_spatial_lengths, b_g_k_c_xs_lengths, spatial_offset); + copy(output_spatial_lengths, c_g_n_k_wos_lengths, spatial_offset); + + // fill only significant values (C and N) + a_g_n_c_wis_lengths[I1] = N; + a_g_n_c_wis_lengths[I2] = C; + b_g_k_c_xs_lengths[I2] = C; + c_g_n_k_wos_lengths[I1] = N; + + ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths, + image_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + {}, // not needed for A Descriptor + c_g_n_k_wos_lengths, + {}, // not needed for A Descriptor + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + + const auto in_gemmmraw_gemmkraw_desc = + conv_to_gemm_transformer.template MakeADescriptor_M_K(); + + const auto in_gemmm_gemmk_desc = + matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc); + 
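// The raw A descriptor has extents (M, K) = (N * Do * Ho * Wo, Z * Y * X * C); the
// matrix padder then rounds M and K up to multiples of MPerBlock and KPerBlock
// according to GemmSpec so the rearrange kernel works on whole tiles.
// Illustrative numbers (assumed, not taken from the interface): N = 2, C = 8,
// Hi = Wi = 16 with a 3x3 filter, stride 1 and pad 1 give Ho = Wo = 16, i.e.
// M = 2 * 16 * 16 = 512 and K = 3 * 3 * 8 = 72 before padding.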
return in_gemmm_gemmk_desc; + } + + static auto + MakeOutDescriptor_M_K(const ck::index_t N, + const ck::index_t C, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& gemm_g_m_k_strides) + { + const index_t NDoHoWo = + N * ck::accumulate_n( + output_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); + const index_t CZYX = + C * ck::accumulate_n( + filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>()); + + const auto desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(NDoHoWo, CZYX), make_tuple(gemm_g_m_k_strides[I1], gemm_g_m_k_strides[I2])); + return matrix_padder.PadADescriptor_M_K(desc_mraw_kraw); + } + + using InputGridDesc = + remove_cvref_t; + using OutputGridDesc = remove_cvref_t; + + using Block2ETileMap = remove_cvref_t< + decltype(BlockToCTileMap_M00_N0_M01Adapt( + OutputGridDesc{}))>; + + using GridwiseTensorRearrangeKernel = GridwiseTensorRearrange>; + + struct Argument : public BaseArgument + { + Argument(const void* p_in, // input image + void* p_out, // gemm form + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + : G_(G), + C_(C), + X_(filter_spatial_lengths[NDimSpatial - I1]), + p_in_{static_cast(p_in)}, + p_out_{static_cast(p_out)}, + image_g_n_c_wis_strides_{image_g_n_c_wis_strides}, + conv_filter_strides_{conv_filter_strides}, + conv_filter_dilations_{conv_filter_dilations}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads} + { + + in_grid_desc_m_k_ = MakeInputDescriptor_M_K(N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + + out_grid_desc_m_k_ = MakeOutDescriptor_M_K( + N, C, filter_spatial_lengths, output_spatial_lengths, gemm_g_m_k_strides); + + compute_ptr_offset_of_batch_.BatchStrideA_ = image_g_n_c_wis_strides[I0]; + compute_ptr_offset_of_batch_.BatchStrideC_ = gemm_g_m_k_strides[I0]; + } + + void Print() const + { + std::cout << in_grid_desc_m_k_ << std::endl; + std::cout << out_grid_desc_m_k_ << std::endl; + } + + const ck::index_t G_; + const ck::index_t C_; + const ck::index_t X_; + + const InputDataType* p_in_; + OutputDataType* p_out_; + + const std::array& image_g_n_c_wis_strides_; + const std::array& conv_filter_strides_; + const std::array& conv_filter_dilations_; + const std::array& input_left_pads_; + const std::array& input_right_pads_; + + InputGridDesc in_grid_desc_m_k_; + OutputGridDesc out_grid_desc_m_k_; + + ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + const auto block_2_tile_map = + BlockToCTileMap_M00_N0_M01Adapt( + arg.out_grid_desc_m_k_); + const index_t grid_size = + block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_) * arg.G_; + const auto kernel = kernel_tensor_rearrange, + GridwiseTensorRearrangeKernel>; + + float elapsed_time = launch_and_time_kernel(stream_config, + 
kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.in_grid_desc_m_k_, + arg.p_in_, + arg.out_grid_desc_m_k_, + arg.p_out_, + arg.G_, + block_2_tile_map, + arg.compute_ptr_offset_of_batch_); + return elapsed_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const Argument& arg) + { + if constexpr(!(is_NSpatialGC || is_GNSpatialC)) + { + return false; + } + + const auto w_pad_left = arg.input_left_pads_[NDimSpatial - I1]; + const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1]; + const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1]; + const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1]; + bool is_w_packed = arg.image_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_; + bool is_c_packed = arg.image_g_n_c_wis_strides_[I2] == 1; + + // check vector acces with c not packed + if(!is_c_packed && ScalarPerVector != 1) + return false; + // check vector access of filter window row (only C if C is not packed) + if(!is_w_packed && arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of filter window row (X * C) + if(arg.X_ * arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of pads (w_pad_left/w_pad_right * C) + if(w_pad_left * arg.C_ % ScalarPerVector != 0 || + w_pad_right * arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of with stride and pad + if((w_pad_left != 0 || w_pad_right != 0) && stride_x > 1 && arg.C_ % ScalarPerVector != 0) + return false; + // check vector access of with dilation + if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0) + return false; + + return GridwiseTensorRearrangeKernel::CheckValidity(arg.in_grid_desc_m_k_, + arg.out_grid_desc_m_k_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_in, // input image + void* p_out, // gemm form + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) + { + return Argument{static_cast(p_in), + static_cast(p_out), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr + MakeArgumentPointer(const void* p_in, // input image + void* p_out, // gemm form + const ck::index_t G, + const ck::index_t N, + const ck::index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& image_g_n_c_wis_strides, + const std::array& gemm_g_m_k_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads) override + { + return std::make_unique(static_cast(p_in), + static_cast(p_out), + G, + N, + C, + input_spatial_lengths, + filter_spatial_lengths, + 
output_spatial_lengths, + image_g_n_c_wis_strides, + gemm_g_m_k_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceImageToColumn" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << KPerBlock << ", " + << ScalarPerVector + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp new file mode 100755 index 000000000000..f86b181e7e12 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp @@ -0,0 +1,366 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/device/device_max_pool_bwd.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// output[indices] = input +template +struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd +{ + using DInDataType_AutomicAddPreCast = + conditional_t || is_same_v, + DInDataType, + float>; + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using UnaryConvert = ck::tensor_operation::element_wise::UnaryConvert; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + template + static auto PadDescriptor_M_1d(Desc_M& desc_m, index_t loop_step) + { + const auto m = desc_m.GetLength(I0); + const auto pad = math::integer_least_multiple(m, loop_step) - m; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(m, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(index_t length, index_t loop_step) + { + const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length)); + return PadDescriptor_M_1d(desc_m, loop_step); + } + + template + static auto ExpendDescFirstDim(Desc_M desc_m) + { + return transform_tensor_descriptor( + desc_m, + make_tuple(make_unmerge_transform(make_tuple(I1, desc_m.GetLength(I0)))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + + using InOutGrid1dDesc = decltype(MakeDescriptor_M(1, 1)); + using InOutGrid2dDesc = decltype(ExpendDescFirstDim(InOutGrid1dDesc{})); + + using GridwisePutElementSet = GridwisePutElement_1D; + + using GridwisePutElementAtomicAdd = GridwisePutElement_1D; + + static constexpr index_t BlockSize = 256; + static constexpr index_t MPerThread = 1; + static constexpr index_t NPerThread = InOutVectorSize; + static constexpr 
index_t MPerBlock = 1; + static constexpr index_t NPerBlock = BlockSize * NPerThread; + + using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt; + + using GridwiseCasting = GridwiseElementwise, + Tuple, + Tuple, + Tuple, + Block2TileMap, + UnaryConvert, + BlockSize, + MPerBlock, + NPerBlock, + MPerThread, + NPerThread, + Sequence<0, 1>, + Sequence, + Sequence, + I1, + I1>; + + struct Argument : public BaseArgument + { + Argument(const DOutDataType* p_dout, + const IndexDataType* p_indices, + DInDataType* p_din, + index_t dout_length, + index_t din_length, + const std::vector& window_lengths, + const std::vector& window_strides, + const std::vector& window_dilations) + : p_dout_{p_dout}, + p_indices_{p_indices}, + p_din_{p_din}, + dout_length_raw_{dout_length}, + din_length_raw_{din_length}, + blockSize_{BlockSize}, + windowOverlap_{false} + { + for(size_t i = 0; i < window_lengths.size(); ++i) + { + auto eff = (window_lengths.at(i) - 1) * window_dilations.at(i) + 1; + windowOverlap_ |= eff > window_strides.at(i); + } + } + + const DOutDataType* p_dout_; + const IndexDataType* p_indices_; + DInDataType* p_din_; + index_t dout_length_raw_; + index_t din_length_raw_; + index_t blockSize_; + bool windowOverlap_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + index_t gridSize = getAvailableComputeUnitCount(stream_config); + index_t loop_step = gridSize * arg.blockSize_ * InOutVectorSize; + InOutGrid1dDesc din_grid_desc = MakeDescriptor_M(arg.din_length_raw_, loop_step); + InOutGrid1dDesc dout_grid_desc = MakeDescriptor_M(arg.dout_length_raw_, loop_step); + + if constexpr(is_same_v || is_same_v) + { + hip_check_error(hipMemsetAsync(arg.p_din_, + 0, + arg.din_length_raw_ * sizeof(DInDataType), + stream_config.stream_id_)); + + if(arg.windowOverlap_) + { + const auto put_kernel = kernel_put_element_1d; + + return launch_and_time_kernel(stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + arg.p_din_, + PassThrough{}); + } + else + { + const auto put_kernel = kernel_put_element_1d; + + return launch_and_time_kernel(stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + arg.p_din_, + PassThrough{}); + } + } + else + { + if(arg.windowOverlap_) + { + if(arg.p_workspace_ == nullptr) + throw std::runtime_error("wrong! 
WorkSpace pointer has not been set"); + + hip_check_error( + hipMemsetAsync(arg.p_workspace_, + 0, + arg.din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast), + stream_config.stream_id_)); + + const auto put_kernel = kernel_put_element_1d; + + const auto cast_kernel = + kernel_elementwise, + Tuple, + Tuple, + Tuple, + Block2TileMap, + UnaryConvert>; + + float elapsed_time = launch_and_time_kernel( + stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + static_cast(arg.p_workspace_), + PassThrough{}); + + InOutGrid2dDesc din_grid_desc_2d = ExpendDescFirstDim(din_grid_desc); + const index_t M = din_grid_desc_2d.GetLength(I0); + const index_t N = din_grid_desc_2d.GetLength(I1); + const auto block_2_tile_map = Block2TileMap(M, N); + const auto cast_kernel_grid_size = + block_2_tile_map.CalculateGridSize(din_grid_desc_2d); + + elapsed_time += launch_and_time_kernel( + stream_config, + cast_kernel, + dim3(cast_kernel_grid_size), + dim3(arg.blockSize_), + 0, + ck::make_tuple(din_grid_desc_2d), + ck::make_tuple(din_grid_desc_2d), + ck::make_tuple( + static_cast(arg.p_workspace_)), + ck::make_tuple(arg.p_din_), + block_2_tile_map, + UnaryConvert{}); + + return elapsed_time; + } + else + { + hip_check_error(hipMemsetAsync(arg.p_din_, + 0, + arg.din_length_raw_ * sizeof(DInDataType), + stream_config.stream_id_)); + + const auto put_kernel = kernel_put_element_1d; + + hip_check_error(hipMemsetAsync(arg.p_din_, + 0, + arg.din_length_raw_ * sizeof(DInDataType), + stream_config.stream_id_)); + + return launch_and_time_kernel(stream_config, + put_kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + dout_grid_desc, + arg.p_dout_, + arg.p_indices_, + arg.p_din_, + PassThrough{}); + } + } + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + bool needCast = pArg_->windowOverlap_ && + !(is_same_v || is_same_v); + + if(!needCast) + return 0; + else + return pArg_->din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast); + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + if(pArg->din_length_raw_ % InOutVectorSize != 0 || + pArg->dout_length_raw_ % InOutVectorSize != 0) + { + return false; + } + return true; + } + + std::unique_ptr + MakeArgumentPointer(const void* p_dout, + const void* p_indices, + void* p_din, + index_t dout_length, + index_t din_length, + std::vector window_lengths, + std::vector window_strides, + std::vector window_dilations) override + { + // Assume p_dout, p_indices, p_din are packed memory space, dout_length and din_length are + // physical size of the packed tensor + return std::make_unique(static_cast(p_dout), + static_cast(p_indices), + static_cast(p_din), + dout_length, + din_length, + window_lengths, + window_strides, + window_dilations); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp new file mode 100755 index 000000000000..27d3c378aceb --- /dev/null +++ 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -0,0 +1,565 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceMoeGemm : public DeviceGemmMultipleDSplitKBPreShuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + using GridwiseGemm = + GridwiseMoeGemm; + + using Argument = typename GridwiseGemm::Argument; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + int GetPreShuffleParameters() override { return NPerXDL; } + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto RunKernel = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + + std::array DsSize; + + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * + sizeof(ADataType) / APackedSize; + auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * + sizeof(BDataType) / BPackedSize; + + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + ck::utility::RotatingMemWrapperMultiD rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = 
ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize / + 4 * (1 + GridwiseGemm::NWave); + constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize / + 4 * (2) * (IsInputGemm ? 2 : 1); + constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) / + BlockSize / 4 * (IsInputGemm ? 2 : 1); + constexpr auto estimated_reg_total = + estimated_reg_a + estimated_reg_b + estimated_reg_c; + + constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2; + + constexpr auto MemoryDataOp = + IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd; + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + } + else + { + throw std::runtime_error("todo: only v1 & v2 support now"); + } + } +#if 1 + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + } + else + { + throw std::runtime_error("todo: only v1 & v2 support now"); + } + } +#endif + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // only impl kbatch 1 now + if(arg.KBatch > 1) + { + return false; + } + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + 
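// A K extent that is not a multiple of AK1/BK1 is only supported when one of the
// K-padding GemmSpecializations is selected; otherwise the argument is rejected.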
{ + return false; + } + if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_sorted_token_ids, + const void* p_sorted_expert_ids, + const void* p_max_token_id, + const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t NumTokens, + index_t TopK, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_sorted_token_ids), + static_cast(p_sorted_expert_ids), + static_cast(p_max_token_id), + static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + NumTokens, + TopK, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(nullptr, + nullptr, + nullptr, + static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, // randoms set, no use + 0, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, {BlockGemmPipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceMoeGEmm" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceMoeGemmBlockScale + : public DeviceGemmMultipleD_BlockScale_BPreshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + using GridwiseGemm = GridwiseMoeGemmBlockScale< + ALayout, + BLayout, + DsLayout, + CLayout, 
+ ADataType, + BDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ActivationOP, + NSwizzle, + IsInputGemm, + MulRoutedWeight, + IndexType, + ComputeTypeA, + ComputeTypeB, + LDSTypeA, + LDSTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + int GetPreShuffleParameters() override { return NPerXDL; } + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + const auto RunKernel = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + + std::array DsSize; + + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * + sizeof(ADataType) / APackedSize; + auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * + sizeof(BDataType) / BPackedSize; + + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + ck::utility::RotatingMemWrapperMultiD rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize / + 4 * (1 + GridwiseGemm::NWave); + constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize / + 4 * (2) * (IsInputGemm ? 2 : 1); + constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) / + BlockSize / 4 * (IsInputGemm ? 2 : 1); + constexpr auto estimated_reg_total = + estimated_reg_a + estimated_reg_b + estimated_reg_c; + + constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2; + + constexpr auto MemoryDataOp = + IsInputGemm ? 
InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + } + else + { + throw std::runtime_error("todo: only v1 & v2 support now"); + } + } +#if 1 + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm; + RunKernel(kernel); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + } + } +#endif + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // only impl kbatch 1 now + if(arg.KBatch > 1) + { + return false; + } + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_sorted_token_ids, + const void* p_sorted_expert_ids, + const void* p_max_token_id, + const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t NumTokens, + index_t TopK, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_sorted_token_ids), + static_cast(p_sorted_expert_ids), + static_cast(p_max_token_id), + static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + NumTokens, + TopK, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + 
static_cast(p_a_scale), + static_cast(p_b_scale), + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + // index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(nullptr, + nullptr, + nullptr, + static_cast(p_a), + static_cast(p_b), + p_ds, + static_cast(p_c), + M, // randoms set, no use + 0, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideC, + static_cast(p_a_scale), + static_cast(p_b_scale), + 1, // KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}}; + + // clang-format off + str << "DeviceMoeGEmm" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + using GridwiseGemm = + GridwiseMoeGemmMX; + + using Argument = typename GridwiseGemm::Argument; + + static constexpr index_t APackedSize = packed_size_v; + static constexpr index_t BPackedSize = packed_size_v; + + int GetPreShuffleParameters() override { return NPerXDL; } + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto RunKernel = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + + std::array DsSize; + + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + ck::utility::RotatingMemWrapperMultiD rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + // TODO: Check if this is the right algorithm for minimum_occupancy + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave + ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && + MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2) + ? 2 + : 1 + : 2; + + constexpr auto MemoryDataOp = + IsInputGemm ? 
InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + const auto kernel = kernel_moe_mxgemm_2lds; + RunKernel(kernel); + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_mxgemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_mxgemm_2lds; + RunKernel(kernel); + } + } + else + { + throw std::runtime_error("todo: only v1 & v3 support now"); + } + } + else + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + const auto kernel = kernel_moe_mxgemm_2lds; + RunKernel(kernel); + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_mxgemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_mxgemm_2lds; + RunKernel(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // only impl kbatch 1 now + if(arg.KBatch > 1) + { + return false; + } + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_sorted_token_ids, + const void* p_sorted_expert_ids, + const void* p_max_token_id, + const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + std::array p_ds, + void* p_c, + index_t NumTokens, + index_t TopK, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideScaleA, + index_t StrideB, + index_t StrideScaleB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_sorted_token_ids), + static_cast(p_sorted_expert_ids), + static_cast(p_max_token_id), + static_cast(p_a), + static_cast(p_a_scale), + static_cast(p_b), + static_cast(p_b_scale), + p_ds, + static_cast(p_c), + NumTokens, + TopK, + M, + N, + K, + StrideA, + StrideScaleA, + StrideB, + StrideScaleB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + std::array p_ds, + void* p_c, + index_t M, + 
index_t N, + index_t K, + index_t StrideA, + index_t StrideScaleA, + index_t StrideB, + index_t StrideScaleB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(nullptr, + nullptr, + nullptr, + static_cast(p_a), + static_cast(p_a_scale), + static_cast(p_b), + static_cast(p_b_scale), + p_ds, + static_cast(p_c), + M, // randoms set, no use + 0, + M, + N, + K, + StrideA, + StrideScaleA, + StrideB, + StrideScaleB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceMoeGEmmMx" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + using GridwiseGemm = + GridwiseMoeGemmMXBNS; + + using Argument = typename GridwiseGemm::Argument; + + static constexpr index_t APackedSize = packed_size_v; + static constexpr index_t BPackedSize = packed_size_v; + + int GetPreShuffleParameters() override { return NPerXDL; } + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto RunKernel = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + + std::array DsSize; + + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + ck::utility::RotatingMemWrapperMultiD rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + // TODO: Check if this is the right algorithm for minimum_occupancy + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave + ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && + MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2) + ? 2 + : 1 + : 2; + + constexpr auto MemoryDataOp = + IsInputGemm ? 
InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd; + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + } + else + { + throw std::runtime_error("todo: only v1 & v3 support now"); + } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // only impl kbatch 1 now + if(arg.KBatch > 1) + { + return false; + } + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_sorted_token_ids, + const void* p_sorted_expert_ids, + const void* p_max_token_id, + const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + std::array p_ds, + void* p_c, + index_t NumTokens, + index_t TopK, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideScaleA, + index_t StrideB, + index_t StrideScaleB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_sorted_token_ids), + static_cast(p_sorted_expert_ids), + static_cast(p_max_token_id), + static_cast(p_a), + static_cast(p_a_scale), + static_cast(p_b), + static_cast(p_b_scale), + p_ds, + static_cast(p_c), + NumTokens, + TopK, + M, + N, + K, + StrideA, + StrideScaleA, + StrideB, + StrideScaleB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t 
StrideScaleA, + index_t StrideB, + index_t StrideScaleB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(nullptr, + nullptr, + nullptr, + static_cast(p_a), + static_cast(p_a_scale), + static_cast(p_b), + static_cast(p_b_scale), + p_ds, + static_cast(p_c), + M, // randoms set, no use + 0, + M, + N, + K, + StrideA, + StrideScaleA, + StrideB, + StrideScaleB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceMoeGEmmMx" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + using GridwiseGemm = GridwiseMoeGemmMX_BPreshuffle< + ALayout, + BLayout, + DsLayout, + CLayout, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + GemmAccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + ScaleBlockSize, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ActivationOP, + NSwizzle, + 
IsInputGemm, + MulRoutedWeight, + IndexType, + ComputeTypeA, + ComputeTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + static constexpr index_t APackedSize = packed_size_v; + static constexpr index_t BPackedSize = packed_size_v; + + int GetPreShuffleParameters() override { return NPerXDL; } + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto RunKernel = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + + std::array DsSize; + + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( + arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); + }); + ck::utility::RotatingMemWrapperMultiD rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + // TODO: Check if this is the right algorithm for minimum_occupancy + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave + ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && + MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2) + ? 2 + : 1 + : 2; + + constexpr auto MemoryDataOp = + IsInputGemm ? 
InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_mxgemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_mxgemm_2lds; + RunKernel(kernel); + } + } + else + { + throw std::runtime_error("todo: only v1 & v3 support now"); + } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_mxgemm; + RunKernel(kernel); + } + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_mxgemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_mxgemm_2lds; + RunKernel(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + // only impl kbatch 1 now + if(arg.KBatch > 1) + { + return false; + } + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_sorted_token_ids, + const void* p_sorted_expert_ids, + const void* p_max_token_id, + const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + std::array p_ds, + void* p_c, + index_t NumTokens, + index_t TopK, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideScaleA, + index_t StrideB, + index_t StrideScaleB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{static_cast(p_sorted_token_ids), + static_cast(p_sorted_expert_ids), + static_cast(p_max_token_id), + static_cast(p_a), + static_cast(p_a_scale), + static_cast(p_b), + static_cast(p_b_scale), + p_ds, + static_cast(p_c), + NumTokens, + TopK, + M, + N, + K, + StrideA, + StrideScaleA, + StrideB, + StrideScaleB, + StrideDs, + StrideC, + KBatch, + 
a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + std::array p_ds, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideScaleA, + index_t StrideB, + index_t StrideScaleB, + std::array StrideDs, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(nullptr, + nullptr, + nullptr, + static_cast(p_a), + static_cast(p_a_scale), + static_cast(p_b), + static_cast(p_b_scale), + p_ds, + static_cast(p_c), + M, // randoms set, no use + 0, + M, + N, + K, + StrideA, + StrideScaleA, + StrideB, + StrideScaleB, + StrideDs, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceMoeGEmmMx" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp" +#include "ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Multi-Query Attention (MQA) kernel implementation +// Assume number of head of K,V is 1. 
+// Q [G0, G1, M, K] * K [G0, 1, K, N] = P [G0, G1, M, N] +// P [G0, G1, M, N] * V [G0, 1, N, O] = Out [G0, G1, M, O] +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_multi_query_attention_wmma(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + index_t M, // SequenceQ + index_t N, // SequenceK + index_t K, // HeadDim + index_t O, // SequenceK + index_t G0, // Batch + index_t G1, // HeadNum + float alpha, + bool input_permute, + bool output_permute) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + + // clang-format off +// *************************************************** + const auto q_head = G1; + const auto kv_head = 1; +// Make Tensor Descriptors + constexpr index_t array_size = 4; + std::array a_gs_ms_ks_lengths{G0, q_head, M, K}; + std::array a_gs_ms_ks_strides = + input_permute + ? std::array{M * q_head * K, K, q_head * K, 1} // A layout [G0, M, G1, K] + : std::array{q_head * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, kv_head, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute + ? std::array{N * kv_head * K, K, kv_head * K, 1} // B0 layout [G0, N, 1, K] + : std::array{kv_head * N * K, N * K, K, 1}; // B0 layout [G0, 1, N, K] + + std::array b1_gs_os_ns_lengths{G0, kv_head, O, N}; + std::array b1_gs_os_ns_strides = + input_permute + ? std::array{N * kv_head * O, O, 1, kv_head * O} // B1 layout [G0, N, 1, O] + : std::array{kv_head * N * O, N * O, 1, O}; // B1 layout [G0, 1, N, O] + + std::array c_gs_ms_os_lengths{G0, q_head, M, O}; + std::array c_gs_ms_os_strides = + output_permute + ? std::array{M * q_head * O, O, q_head * O, 1} // C layout [G0, M, G1, O] + : std::array{q_head * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_element_op = AElementwiseOperation{}; + const auto b0_element_op = B0ElementwiseOperation{}; + const auto acc0_element_op = AccElementwiseOperation{alpha}; + const auto b1_element_op = B1ElementwiseOperation{}; + const auto c_element_op = CElementwiseOperation{}; + // fail to reuse DeviceOp::MakeArgument() because of the __device__ function required. 
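A few lines further down, the batch-offset helpers divide the flattened batch index by the query-head count (GetB0BasePtr(g_idx / G1), GetB1BasePtr(g_idx / G1)) so that every query head in a group reads the single shared K/V head, while Q and the output stay per-(batch, head). The following is a minimal standalone sketch of that index arithmetic only; the G0/G1 values are made-up examples, not taken from this kernel.

// Illustrative only: how an MQA launch maps G0 * G1 (batch, query-head) pairs
// onto G0 shared K/V batches. Builds and runs as a plain host program.
#include <cassert>

int main()
{
    const int G0 = 2; // batch count (example value)
    const int G1 = 8; // query-head count; the K/V head count is 1

    for(int g_idx = 0; g_idx < G0 * G1; ++g_idx)
    {
        const int q_batch  = g_idx;      // Q and the output are offset per (batch, head)
        const int kv_batch = g_idx / G1; // K and V are offset per batch only
        assert(kv_batch < G0);
        assert(kv_batch * G1 <= q_batch && q_batch < (kv_batch + 1) * G1);
    }
    return 0;
}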
+ + const auto a_grid_desc = DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto a_grid_desc_g_m_k = + DeviceOp::Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc_g_l_k = + DeviceOp::Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc_g_n_l = + DeviceOp::Transform::MakeB1GridDescriptor_G_N_K(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + const auto compute_base_ptr_of_batch = + typename DeviceOp::ComputeBasePtrOfStridedBatch{a_grid_desc_g_m_k, b0_grid_desc_g_l_k, b1_grid_desc_g_n_l, c_grid_desc_g_m_n}; + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + const auto c0_matrix_mask = typename DeviceOp::C0MatrixMask{b0_grid_desc_g_l_k.GetLength(Number<1>{})}; + + // clang-format on + __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()]; + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetABasePtr(g_idx))); + const long_index_t b0_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB0BasePtr(g_idx / G1))); + const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetB1BasePtr(g_idx / G1))); + const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_base_ptr_of_batch.GetCBasePtr(g_idx))); + + GridwiseOp::template Run(p_a_grid + a_batch_offset, + p_b0_grid + b0_batch_offset, + p_b1_grid + b1_batch_offset, + p_c_grid + c_batch_offset, + p_shared, + a_grid_desc, + b0_grid_desc, + b1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b0_element_op, + acc0_element_op, + b1_element_op, + c_element_op, + c0_matrix_mask, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b0_grid; + ignore = p_b1_grid; + ignore = p_c_grid; + ignore = M; + ignore = N; + ignore = K; + ignore = O; + ignore = G0; + ignore = G1; + ignore = alpha; + ignore = input_permute; + ignore = output_permute; +#endif // end of if (defined(__gfx11__)) +} + +// Computes C = A * B0 * B1 +// MN = MK * KL * LN +// ^^^^^^ (Acc0) +// ^^^^^^^^^^^ (Acc1) +template +struct DeviceMultiQueryAttentionForward_Wmma + : public DeviceBatchedGemmSoftmaxGemmPermute +{ + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimL > 0 && NumDimK > 0 && NumDimN > 0, + "Number of dimension must be greater than 0"); + + static constexpr index_t NumAcc0Bias = Acc0BiasDataType::Size(); + static constexpr index_t NumAcc1Bias = Acc1BiasDataType::Size(); + + // TODO ANT: implement bias combination + 
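The comment above frames the operation as two chained GEMMs (MN = MK * KL * LN), and the NumDimGemm0*/NumDimGemm1* aliases defined just below encode that the second GEMM contracts over the first GEMM's output dimension L. The shape-only sketch here shows why those aliases line up; the concrete sizes are hypothetical example values, not configuration defaults.

// Shape-only illustration of C = A * B0 * B1 with the M/L/K/N naming used in this file.
#include <cassert>

int main()
{
    const int M = 128, K = 64, L = 256, N = 64; // made-up sizes

    // Gemm0: Acc0[M, L] = A[M, K] * B0[K, L]  -> Gemm0's "N" is L, its "K" is K
    const int acc0_rows = M, acc0_cols = L;
    (void)K;

    // The row-wise softmax keeps Acc0 at [M, L].

    // Gemm1: C[M, N] = Acc0[M, L] * B1[L, N]  -> Gemm1's "K" is L
    assert(acc0_cols == L);
    const int c_rows = acc0_rows, c_cols = N;
    assert(c_rows == M && c_cols == N);
    return 0;
}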
static_assert(NumAcc0Bias == 0 && NumAcc0Bias == 0, "Bias addition is unimplemented"); + + static constexpr index_t NumDimGemm0M = NumDimM; + static constexpr index_t NumDimGemm0N = NumDimL; + static constexpr index_t NumDimGemm0K = NumDimK; + static constexpr index_t NumDimGemm1M = NumDimM; + static constexpr index_t NumDimGemm1N = NumDimN; + static constexpr index_t NumDimGemm1K = NumDimL; + + using DeviceOp = DeviceMultiQueryAttentionForward_Wmma; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + + static constexpr auto WmmaK = 16; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + + static constexpr auto AEnableLds_auto = LWaves == 1 ? false : true; + static constexpr auto B0EnableLds_auto = MWaves == 1 ? false : true; + static constexpr auto B1EnableLds_auto = MWaves == 1 ? false : true; + + static constexpr auto AEnableLds_manu = false; + static constexpr auto B0EnableLds_manu = true; + static constexpr auto B1EnableLds_manu = true; + + static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1); + static constexpr auto B0EnableLds = B0EnableLds_auto || B0EnableLds_manu || (NumPrefetch > 1); + static constexpr auto B1EnableLds = B1EnableLds_auto || B1EnableLds_manu || (NumPrefetch > 1); + + using Transform = TransformBatchedContractionContractionToBatchedGemmGemm_Wmma< + Sequence, + Sequence, + GemmSpec, + ASpec, + B0Spec, + B1Spec, + CSpec>; + + __host__ __device__ static auto MakeAGridDescriptor( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + if constexpr(AEnableLds) + { + return Transform::MakeAGridDescriptor_AK0_M_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1( + Transform::MakeAGridDescriptor_M_K(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB0GridDescriptor( + const std::array& b0_gs_ls_ks_lengths_vec, + const std::array& b0_gs_ls_ks_strides_vec) + { + if constexpr(B0EnableLds) + { + return Transform::MakeB0GridDescriptor_BK0_N_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1( + Transform::MakeB0GridDescriptor_N_K(b0_gs_ls_ks_lengths_vec, + b0_gs_ls_ks_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + __host__ __device__ static auto MakeB1GridDescriptor( + const std::array& b1_gs_ns_ls_lengths_vec, + const std::array& b1_gs_ns_ls_strides_vec) + { + if constexpr(B1EnableLds) + { + return Transform::MakeB1GridDescriptor_BK0_N_BK1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + b1_gs_ns_ls_strides_vec), + Number{}); + } + else + { + return Transform:: + MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1( + Transform::MakeB1GridDescriptor_N_K(b1_gs_ns_ls_lengths_vec, + 
b1_gs_ns_ls_strides_vec), + Number{}, + Number{}, + Number{}, + Number{}, + Number{}); + } + } + + using AGridDesc = decltype(MakeAGridDescriptor({}, {})); + using B0GridDesc = decltype(MakeB0GridDescriptor({}, {})); + using B1GridDesc = decltype(MakeB1GridDescriptor({}, {})); + using CGridDesc_M_N = decltype(Transform::MakeCGridDescriptor_M_N({}, {})); + using AGridDesc_G_M_K = decltype(Transform::MakeAGridDescriptor_G_M_K({}, {})); + using B0GridDesc_G_L_K = decltype(Transform::MakeB0GridDescriptor_G_N_K({}, {})); + using B1GridDesc_G_N_L = decltype(Transform::MakeB1GridDescriptor_G_N_K({}, {})); + using CGridDesc_G_M_N = decltype(Transform::MakeCGridDescriptor_G_M_N({}, {})); + + __host__ __device__ constexpr static auto make_MaskOutPredicate() + { + if constexpr(MaskingSpec == MaskingSpecialization::MaskDisabled) + { + return MaskDisabledPredicate{}; + } + else if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle) + { + return MaskOutUpperTrianglePredicate{}; + } + } + using C0MatrixMask = C0MatrixMask_impl; + + struct ComputeBasePtrOfStridedBatch + { + __host__ __device__ ComputeBasePtrOfStridedBatch(const AGridDesc_G_M_K& a_grid_desc_g_m_k, + const B0GridDesc_G_L_K& b0_grid_desc_g_l_k, + const B1GridDesc_G_N_L& b1_grid_desc_g_n_l, + const CGridDesc_G_M_N& c_grid_desc_g_m_n) + : a_grid_desc_g_m_k_(a_grid_desc_g_m_k), + b0_grid_desc_g_l_k_(b0_grid_desc_g_l_k), + b1_grid_desc_g_n_l_(b1_grid_desc_g_n_l), + c_grid_desc_g_m_n_(c_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const + { + return a_grid_desc_g_m_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB0BasePtr(index_t g_idx) const + { + return b0_grid_desc_g_l_k_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const + { + return b1_grid_desc_g_n_l_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const + { + return c_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0)); + } + + private: + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + }; + + // GridwiseOp + using GridwiseOp = GridwiseBatchedGemmSoftmaxGemm_Wmma< + // DataType Family + ADataType, + B0DataType, + Acc0DataType, + B1DataType, + Acc1DataType, + CShuffleDataType, + CDataType, + // ElementwiseOp Family + AElementwiseOperation, + B0ElementwiseOperation, + AccElementwiseOperation, + B1ElementwiseOperation, + CElementwiseOperation, + InMemoryDataOperationEnum::Set, + // InMemory Data Descriptor + AGridDesc, + B0GridDesc, + B1GridDesc, + CGridDesc_M_N, + // Tiling Family + MPerBlock, + LPerBlock, + KPerBlock, + AK1, + BK1, + NPerBlock, + LTilePerBlock, + L1, + MPerWmma, + LPerWmma, + NPerWmma, + MRepeat, + LRepeat, + NRepeat, + // ThreadCluster Family + BlockSize, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + true, + AEnableLds, + ABlockLdsAddExtraM, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0BlockTransferSrcAccessOrder, + B0BlockTransferSrcVectorDim, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_K1, + true, + B0EnableLds, + 
B0BlockLdsAddExtraL, + B1BlockTransferThreadClusterLengths_L0_N_L1, + B1BlockTransferThreadClusterArrangeOrder, + B1BlockTransferSrcAccessOrder, + B1BlockTransferSrcVectorDim, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_L1, + false, + B1EnableLds, + B1BlockLdsAddExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + Transform::matrix_padder.PadN, + MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle, + NumPrefetch, + LoopSched, + PipelineVer>; + + struct RawArg : public BaseArgument + { + RawArg(const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + M_{M}, + N_{N}, + K_{K}, + O_{O}, + G0_{G0}, + G1_{G1}, + alpha_{alpha}, + input_permute_{input_permute}, + output_permute_{output_permute} + { + } + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Raw Problem Size + index_t M_; + index_t N_; + index_t K_; + index_t O_; + index_t G0_; + index_t G1_; + float alpha_; + bool input_permute_; + bool output_permute_; + }; + + static auto MakeArgument(const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t O, + index_t G0, + index_t G1, + float alpha, + bool input_permute, + bool output_permute) + { + return RawArg{ + p_a, p_b0, p_b1, p_c, M, N, K, O, G0, G1, alpha, input_permute, output_permute}; + } + + static bool IsSupportedArgument(const RawArg& arg) + { + if(ck::is_gfx11_supported() || ck::is_gfx12_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + constexpr index_t array_size = 4; + ck::index_t G0 = arg.G0_; + ck::index_t G1 = arg.G1_; + ck::index_t M = arg.M_; + ck::index_t N = arg.N_; + ck::index_t K = arg.K_; + ck::index_t O = arg.O_; + bool input_permute = arg.input_permute_; + bool output_permute = arg.output_permute_; + + std::array a_gs_ms_ks_lengths{G0, G1, M, K}; + std::array a_gs_ms_ks_strides = + input_permute ? std::array{M * G1 * K, K, G1 * K, 1} + // A layout [G0, M, G1, K] + : std::array{ + G1 * M * K, M * K, K, 1}; // A layout [G0, G1, M, K] + + std::array b0_gs_ns_ks_lengths{G0, G1, N, K}; + std::array b0_gs_ns_ks_strides = + input_permute ? std::array{N * G1 * K, K, G1 * K, 1} + // B0 layout [G0, N, G1, K] + : std::array{ + G1 * N * K, N * K, K, 1}; // B0 layout [G0, G1, N, K] + + std::array b1_gs_os_ns_lengths{G0, G1, O, N}; + std::array b1_gs_os_ns_strides = + input_permute ? std::array{N * G1 * O, O, 1, G1 * O} + // B1 layout [G0, N, G1, O] + : std::array{ + G1 * N * O, N * O, 1, O}; // B1 layout [G0, G1, N, O] + + std::array c_gs_ms_os_lengths{G0, G1, M, O}; + std::array c_gs_ms_os_strides = + output_permute ? 
std::array{M * G1 * O, O, G1 * O, 1} + // C layout [G0, M, G1, O] + : std::array{ + G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O] + + const auto a_grid_desc = + DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides); + const auto b0_grid_desc = + DeviceOp::MakeB0GridDescriptor(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides); + const auto b1_grid_desc = + DeviceOp::MakeB1GridDescriptor(b1_gs_os_ns_lengths, b1_gs_os_ns_strides); + const auto c_grid_desc_m_n = + DeviceOp::Transform::MakeCGridDescriptor_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + + const auto block_2_ctile_map = GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, 1, 1); + + const auto c_grid_desc_g_m_n = + DeviceOp::Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_os_lengths, c_gs_ms_os_strides); + index_t batch_count = c_grid_desc_g_m_n.GetLength(Number<0>{}); + + if(!GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n, block_2_ctile_map)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = c_grid_desc_g_m_n.GetLength(I0); // unpadded + + if(!(c_g == batch_count)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = M; + const auto LzRaw = N; + const auto KzRaw = K; + const auto NzRaw = O; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + std::array a_mz_kz_strides_{ + a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}; + std::array b0_lz_kz_strides_{ + b0_gs_ns_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ns_ks_strides[NumDimG + NumDimL + NumDimK - 1]}; + std::array b1_nz_lz_strides_{ + b1_gs_os_ns_strides[NumDimG + NumDimN - 1], + b1_gs_os_ns_strides[NumDimG + NumDimN + NumDimL - 1]}; + std::array c_mz_nz_strides_{ + c_gs_ms_os_strides[NumDimG + NumDimM - 1], + c_gs_ms_os_strides[NumDimG + NumDimM + NumDimN - 1]}; + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? a_mz_kz_strides_[1] : a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? b0_lz_kz_strides_[1] : b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
b1_nz_lz_strides_[1] : b1_nz_lz_strides_[0]; + const auto c_stride_lowest = c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + // Argument + struct Argument : public BaseArgument + { + Argument( + const ADataType* p_a_grid, + const B0DataType* p_b0_grid, + const B1DataType* p_b1_grid, + CDataType* p_c_grid, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + const index_t M01, + const index_t N01, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + : p_a_grid_{p_a_grid}, + p_b0_grid_{p_b0_grid}, + p_b1_grid_{p_b1_grid}, + p_c_grid_{p_c_grid}, + a_grid_desc{DeviceOp::MakeAGridDescriptor(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc{ + DeviceOp::MakeB0GridDescriptor(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc{ + DeviceOp::MakeB1GridDescriptor(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_m_n_{ + Transform::MakeCGridDescriptor_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + a_grid_desc_g_m_k_{ + Transform::MakeAGridDescriptor_G_M_K(a_gs_ms_ks_lengths, a_gs_ms_ks_strides)}, + b0_grid_desc_g_l_k_{ + Transform::MakeB0GridDescriptor_G_N_K(b0_gs_ls_ks_lengths, b0_gs_ls_ks_strides)}, + b1_grid_desc_g_n_l_{ + Transform::MakeB1GridDescriptor_G_N_K(b1_gs_ns_ls_lengths, b1_gs_ns_ls_strides)}, + c_grid_desc_g_m_n_{ + Transform::MakeCGridDescriptor_G_M_N(c_gs_ms_ns_lengths, c_gs_ms_ns_strides)}, + c_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_ctile_map_{GridwiseOp::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01)}, + a_element_op_{a_element_op}, + b0_element_op_{b0_element_op}, + acc_element_op_{acc_element_op}, + b1_element_op_{b1_element_op}, + c_element_op_{c_element_op}, + c0_matrix_mask_{b0_grid_desc_g_l_k_.GetLength(I1)}, + raw_lengths_mz_lz_kz_nz_{a_gs_ms_ks_lengths[NumDimG + NumDimM - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL - 1], + b0_gs_ls_ks_lengths[NumDimG + NumDimL + NumDimK - 1], + b1_gs_ns_ls_lengths[NumDimG + NumDimN - 1]}, + a_mz_kz_strides_{a_gs_ms_ks_strides[NumDimG + NumDimM - 1], + a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]}, + b0_lz_kz_strides_{b0_gs_ls_ks_strides[NumDimG + NumDimL - 1], + b0_gs_ls_ks_strides[NumDimG + NumDimL + NumDimK - 1]}, + b1_nz_lz_strides_{b1_gs_ns_ls_strides[NumDimG + NumDimN - 1], + b1_gs_ns_ls_strides[NumDimG + NumDimN + NumDimL - 1]}, + c_mz_nz_strides_{c_gs_ms_ns_strides[NumDimG + NumDimM - 1], + c_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]}, + batch_count_{c_grid_desc_g_m_n_.GetLength(I0)}, + compute_ptr_offset_of_batch_{ + 
a_grid_desc_g_m_k_, b0_grid_desc_g_l_k_, b1_grid_desc_g_n_l_, c_grid_desc_g_m_n_} + { + // TODO ANT: implement bias addition + ignore = p_acc0_biases; + ignore = p_acc1_biases; + ignore = acc0_biases_gs_ms_ls_lengths; + ignore = acc0_biases_gs_ms_ls_strides; + ignore = acc1_biases_gs_ms_ns_lengths; + ignore = acc1_biases_gs_ms_ns_strides; + + if(GridwiseOp::CheckValidity( + a_grid_desc, b0_grid_desc, b1_grid_desc, c_grid_desc_m_n_, block_2_ctile_map_)) + { + c_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseOp::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n_); + } + } + + // Pointers + const ADataType* p_a_grid_; + const B0DataType* p_b0_grid_; + const B1DataType* p_b1_grid_; + CDataType* p_c_grid_; + + // Tensor Descriptors + AGridDesc a_grid_desc; + B0GridDesc b0_grid_desc; + B1GridDesc b1_grid_desc; + CGridDesc_M_N c_grid_desc_m_n_; + + AGridDesc_G_M_K a_grid_desc_g_m_k_; + B0GridDesc_G_L_K b0_grid_desc_g_l_k_; + B1GridDesc_G_N_L b1_grid_desc_g_n_l_; + CGridDesc_G_M_N c_grid_desc_g_m_n_; + + typename GridwiseOp::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock_; + + // Block to Tile mapping + typename GridwiseOp::DefaultBlock2CTileMap block_2_ctile_map_; + + // ElementwiseOp + AElementwiseOperation a_element_op_; + B0ElementwiseOperation b0_element_op_; + AccElementwiseOperation acc_element_op_; + B1ElementwiseOperation b1_element_op_; + CElementwiseOperation c_element_op_; + + // check C0 masking and padding + C0MatrixMask c0_matrix_mask_; + + // Strides for the last M/N/K dimensions of A/B0/B1/C + // for sanity check of vector load/store + std::array raw_lengths_mz_lz_kz_nz_; + std::array a_mz_kz_strides_; + std::array b0_lz_kz_strides_; + std::array b1_nz_lz_strides_; + std::array c_mz_nz_strides_; + + index_t batch_count_; + // Batch Offset + ComputeBasePtrOfStridedBatch compute_ptr_offset_of_batch_; + }; + + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::RawArg; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto M0 = math::integer_divide_ceil(arg.M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(arg.O_, NPerBlock); + + const index_t grid_size = arg.G0_ * arg.G1_ * M0 * N0; + const auto K = arg.K_; + // printf("HasKBlockLoop: %d\n", GridwiseOp::CalculateHasMainKBlockLoop(K)); + auto launch_kernel = [&](auto has_main_k_block_loop) { + const auto kernel = kernel_multi_query_attention_wmma; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b0_grid_, + arg.p_b1_grid_, + arg.p_c_grid_, + arg.M_, + arg.N_, + arg.K_, + arg.O_, + arg.G0_, + arg.G1_, + arg.alpha_, + arg.input_permute_, + arg.output_permute_); + }; + + if(GridwiseOp::CalculateHasMainKBlockLoop(K)) + { + return launch_kernel(integral_constant{}); + } + else + { + return launch_kernel(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } +#if 0 + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::is_gfx11_supported()) + { + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc0 Type err"); + return false; + } + + if constexpr(!(is_same_v || is_same_v)) + { + printf("DeviceOp: Acc1 
Type err"); + return false; + } + } + else + { + printf("DeviceOp: Arch err"); + return false; + } + + if(!GridwiseOp::CheckValidity(arg.a_grid_desc, + arg.b0_grid_desc, + arg.b1_grid_desc, + arg.c_grid_desc_m_n_, + arg.block_2_ctile_map_)) + { + return false; + } + + // Check if C permute dimension matches GEMM + GEMM shape + const index_t c_g = arg.c_grid_desc_g_m_n_.GetLength(I0); // unpadded + + if(!(c_g == arg.batch_count_)) + { + printf("DeviceOp: BatchCount err"); + return false; + } + + // Note: we need raw lengths since threadwise copy can not handle vector load when part of + // vector is out of bounds + // Note: need lowest dim in Ms/Ns/Ks/Os, not merged M/N/K/O + const auto MzRaw = arg.raw_lengths_mz_lz_kz_nz_[0]; + const auto LzRaw = arg.raw_lengths_mz_lz_kz_nz_[1]; + const auto KzRaw = arg.raw_lengths_mz_lz_kz_nz_[2]; + const auto NzRaw = arg.raw_lengths_mz_lz_kz_nz_[3]; + + // Check scalar per vector requirement + const auto a_extent_lowest = ABlockTransferSrcVectorDim == 2 ? KzRaw : MzRaw; + const auto b0_extent_lowest = B0BlockTransferSrcVectorDim == 2 ? KzRaw : LzRaw; + const auto b1_extent_lowest = B1BlockTransferSrcVectorDim == 2 ? LzRaw : NzRaw; + const auto c_extent_lowest = NzRaw; + + if(!(a_extent_lowest % ABlockTransferSrcScalarPerVector == 0 && + b0_extent_lowest % B0BlockTransferSrcScalarPerVector == 0 && + b1_extent_lowest % B1BlockTransferSrcScalarPerVector == 0 && + c_extent_lowest % CShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + { + printf("DeviceOp: Data Transfer Vector scalar err"); + return false; + } + + // Check vector load/store requirement + const auto a_stride_lowest = + ABlockTransferSrcVectorDim == 2 ? arg.a_mz_kz_strides_[1] : arg.a_mz_kz_strides_[0]; + const auto b0_stride_lowest = + B0BlockTransferSrcVectorDim == 2 ? arg.b0_lz_kz_strides_[1] : arg.b0_lz_kz_strides_[0]; + const auto b1_stride_lowest = + B1BlockTransferSrcVectorDim == 2 ? 
arg.b1_nz_lz_strides_[1] : arg.b1_nz_lz_strides_[0]; + const auto c_stride_lowest = arg.c_mz_nz_strides_[1]; + + if(!(a_stride_lowest == 1 || b0_stride_lowest == 1 || b1_stride_lowest == 1 || + c_stride_lowest == 1)) + { + printf("DeviceOp: Data Vectorize transfer err"); + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument( + const ADataType* p_a, + const B0DataType* p_b0, + const B1DataType* p_b1, + CDataType* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::array& a_gs_ms_ks_lengths, + const std::array& a_gs_ms_ks_strides, + const std::array& b0_gs_ls_ks_lengths, + const std::array& b0_gs_ls_ks_strides, + const std::array& b1_gs_ns_ls_lengths, + const std::array& b1_gs_ns_ls_strides, + const std::array& c_gs_ms_ns_lengths, + const std::array& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b0, + p_b1, + p_c, + p_acc0_biases, + p_acc1_biases, + a_gs_ms_ks_lengths, + a_gs_ms_ks_strides, + b0_gs_ls_ks_lengths, + b0_gs_ls_ks_strides, + b1_gs_ns_ls_lengths, + b1_gs_ns_ls_strides, + c_gs_ms_ns_lengths, + c_gs_ms_ns_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op}; + } +#endif + + // polymorphic + std::unique_ptr MakeArgumentPointer( + const void* p_a, + const void* p_b0, + const void* p_b1, + void* p_c, + const std::array p_acc0_biases, + const std::array p_acc1_biases, + const std::vector& a_gs_ms_ks_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b0_gs_ls_ks_lengths, + const std::vector& b0_gs_ls_ks_strides, + const std::vector& b1_gs_ns_ls_lengths, + const std::vector& b1_gs_ns_ls_strides, + const std::vector& c_gs_ms_ns_lengths, + const std::vector& c_gs_ms_ns_strides, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_lengths, + const std::array, NumAcc0Bias> acc0_biases_gs_ms_ls_strides, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_lengths, + const std::array, NumAcc1Bias> acc1_biases_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + B0ElementwiseOperation b0_element_op, + AccElementwiseOperation acc_element_op, + B1ElementwiseOperation b1_element_op, + CElementwiseOperation c_element_op) override + { + std::array a_lengths; + std::array a_strides; + std::array b0_lengths; + std::array b0_strides; + std::array b1_lengths; + std::array b1_strides; + std::array c_lengths; + std::array c_strides; + std::transform(a_gs_ms_ks_lengths.begin(), + a_gs_ms_ks_lengths.end(), + a_lengths.begin(), + [](index_t i) { return i; }); + std::transform(a_gs_ms_ks_strides.begin(), + a_gs_ms_ks_strides.end(), + a_strides.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_lengths.begin(), + b0_gs_ls_ks_lengths.end(), + b0_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b0_gs_ls_ks_strides.begin(), + 
b0_gs_ls_ks_strides.end(), + b0_strides.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_lengths.begin(), + b1_gs_ns_ls_lengths.end(), + b1_lengths.begin(), + [](index_t i) { return i; }); + std::transform(b1_gs_ns_ls_strides.begin(), + b1_gs_ns_ls_strides.end(), + b1_strides.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_lengths.begin(), + c_gs_ms_ns_lengths.end(), + c_lengths.begin(), + [](index_t i) { return i; }); + std::transform(c_gs_ms_ns_strides.begin(), + c_gs_ms_ns_strides.end(), + c_strides.begin(), + [](index_t i) { return i; }); + return std::make_unique(static_cast(p_a), + static_cast(p_b0), + static_cast(p_b1), + static_cast(p_c), + p_acc0_biases, + p_acc1_biases, + a_lengths, + a_strides, + b0_lengths, + b0_strides, + b1_lengths, + b1_strides, + c_lengths, + c_strides, + acc0_biases_gs_ms_ls_lengths, + acc0_biases_gs_ms_ls_strides, + acc1_biases_gs_ms_ns_lengths, + acc1_biases_gs_ms_ns_strides, + 1, + 1, + a_element_op, + b0_element_op, + acc_element_op, + b1_element_op, + c_element_op); + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map LoopSchedToString{ + {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}}; + + std::map PipelineVersionToString{{PipelineVersion::v1, "v1"}, + {PipelineVersion::v2, "v2"}}; + + // clang-format off + str << "DeviceMultiQueryAttentionForward_Wmma" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << LPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << LTilePerBlock << ", " + << L1 << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << "ASpec" << getTensorSpecializationString(ASpec) << ", " + << "B0Spec" << getTensorSpecializationString(B0Spec) << ", " + << "B1Spec" << getTensorSpecializationString(B1Spec) << ", " + << "CSpec" << getTensorSpecializationString(CSpec) << ", " + << getMaskingSpecializationString(MaskingSpec) + << ">" + << " AEnableLds: " + << AEnableLds << ", " + << "B0EnableLds: " + << B0EnableLds << ", " + << "B1EnableLds: " + << B1EnableLds << ", " + << "NumPrefetch: " + << NumPrefetch << ", " + << "LoopScheduler: " + << LoopSchedToString[LoopSched] << ", " + << "PipelineVersion: " + << PipelineVersionToString[PipelineVer]; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp new file mode 100755 index 000000000000..aec5a65ccf41 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_multiblock.hpp @@ -0,0 +1,595 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
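+//
+// Overview (descriptive note; sizing numbers below are illustrative only, not
+// taken from any particular instantiation): DeviceMultipleReduceMultiBlock
+// computes NumReduction reductions over the same input tensor in one pass.
+// When OutMemoryDataOperation == AtomicAdd (the multiblock path), the reduce
+// (K) dimension is split across blkGroupSize workgroups whose partial results
+// are atomically accumulated into the output buffers; a preparatory kernel
+// first fills those buffers with the reduction identity. When
+// OutMemoryDataOperation == Set, the same struct acts as the blockwise
+// variant: one workgroup per M tile, no atomics.
+//
+// Example of the blkGroupSize search performed in Argument: with
+// K_BlockTileSize = 512 and reduce_total_length = 100000, one K tile per
+// iteration would need ceil(100000 / 512) = 196 workgroups along K, which
+// exceeds the intended limit of 128; two iterations give
+// blkGroupSize = ceil(100000 / 1024) = 98, so numBlockTileIteration = 2 and
+// blkGroupSize = 98 are chosen.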
+ +#pragma once + +#include +#include + +#include "ck/utility/sequence.hpp" +#include "ck/utility/reduction_operator.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/device_multiple_reduce.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp" + +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceMultipleReduceMultiBlock : public DeviceMultipleReduce +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, + "Invalid thread cluster size assignments!"); + + static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert(NumReduction == OutDataTypeTuple::Size() && + NumReduction == InElementwiseOperationTuple::Size() && + NumReduction == AccElementwiseOperationTuple::Size() && + NumReduction == OutDstVectorSizeSeq::Size(), + "All tuple should have the same size as the number of Reductions!"); + + static_assert(sequence_all_of(OutDstVectorSizeSeq{}, + [](auto vectorSize) { + return (MThreadSliceSize % vectorSize == 0); + }), + "The OutDstVectorSize should completely divide the MThreadSliceSize!"); + + static constexpr bool CheckDataTypeTuple() + { + bool flag = true; + + static_for<0, NumReduction, 1>{}([&](auto I) { + using OutDataType = remove_cvref_t; + flag = + flag && ck::reduce::InMemoryDataOperationSupportedOnDataType::value; + }); + + return flag; + }; + + static_assert(CheckDataTypeTuple(), + "The OutDataType must support the specified OutMemoryDataOperation!"); + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + static constexpr index_t NumInputDim = Rank; + static constexpr index_t NumOutputDim = (NumInvariantDim == 0) ? 
1 : NumInvariantDim; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + // So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added + // later + static constexpr bool use_multiblock = + (OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd); + + static_assert( + ReduceOperation::IsCompatibleInMemoryDataOperation(OutMemoryDataOperation), + "The reduction accumulation operation must be compatible with the OutMemoryDataOperation!"); + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto GenerateOutDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple()); + + static auto MakeSrc2dDescriptor(const std::array& inLengths, + const std::array& inStrides, + int blkGroupSize, + int numBlockTileIteration) + { + const auto tupleSrcLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + const auto tupleSrcStrides = + generate_tuple([&](auto I) { return inStrides[I]; }, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumInputDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = generate_tuple( + [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + static auto MakeDst1dDescriptor(const std::array& outLengths, + const std::array& outStrides) + { + const auto tupleDstLengths = + generate_tuple([&](auto I) { return outLengths[I]; }, Number{}); + const auto tupleDstStrides = + generate_tuple([&](auto I) { return outStrides[I]; }, Number{}); + + auto outDesc 
= make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumOutputDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); + + const auto outPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto out_grid_desc_m_padded = transform_tensor_descriptor( + out_grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, outPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + static auto GenerateOutGrid1dDescTuple() + { + return generate_tuple( + [&](auto I) { + (void)I; + return MakeDst1dDescriptor(std::array{}, + std::array{}); + }, + Number{}); + }; + + using InGridDesc_M_K = decltype(MakeSrc2dDescriptor( + std::array{}, std::array{}, 1, 1)); + using OutGridDesc_M_Tuple = decltype(GenerateOutGrid1dDescTuple()); + + static auto MakeDst1dDescriptorForBufferSet(const std::array& outLengths, + const std::array& outStrides) + { + const auto tupleDstLengths = + generate_tuple([&](auto I) { return outLengths[I]; }, Number{}); + const auto tupleDstStrides = + generate_tuple([&](auto I) { return outStrides[I]; }, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumOutputDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto length = out_grid_desc_m.GetLength(Number<0>{}); + + const auto pad = math::integer_least_multiple(length, BlockSize) - length; + + auto out_grid_desc_m_padded = + transform_tensor_descriptor(out_grid_desc_m, + make_tuple(make_right_pad_transform(length, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + static auto GenerateOutGrid1dDescTuple_2() + { + return generate_tuple( + [&](auto I) { + (void)I; + return MakeDst1dDescriptorForBufferSet(std::array{}, + std::array{}); + }, + Number{}); + }; + + using OutGridDesc_M_Tuple_2 = decltype(GenerateOutGrid1dDescTuple_2()); + + struct Argument : public BaseArgument + { + Argument(const std::array& inLengths, + const std::array& inStrides, + const std::array& outLengths, + const std::array, NumReduction>& outStridesArray, + const std::array& reduceDims, + const std::array& alphas, + const std::array& betas, + const void* in_dev, + const std::array& out_dev_buffers, + const InElementwiseOperationTuple in_elementwise_op_tuple, + const AccElementwiseOperationTuple acc_elementwise_op_tuple) + : outLengths_{outLengths}, + outStridesArray_{outStridesArray}, + in_elementwise_op_tuple_{in_elementwise_op_tuple}, + acc_elementwise_op_tuple_{acc_elementwise_op_tuple} + { + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); + + for(size_t i = 0; i < NumReduction; i++) + { + alpha_values_(i) = static_cast(alphas[i]); + beta_values_(i) = static_cast(betas[i]); + }; + + in_dev_ = static_cast(in_dev); + + out_dev_buffers_ = generate_tuple( + [&](auto iR) { + using OutDataTypePointer = + remove_cvref_t; + using OutDataType = remove_cvref_t>; + return static_cast(out_dev_buffers[iR]); + }, + Number{}); + + std::tie(invariant_total_length, reduce_total_length) = 
+ get_2d_lengths(inLengths_); + + if constexpr(use_multiblock) + { + + int iterations = 1; + while(true) + { + int testBlkGroupSize = + (reduce_total_length + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + // we want the blkGroupSize be not more than 128 + if(testBlkGroupSize <= 128) + break; + + iterations++; + }; + + blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + numBlockTileIteration = iterations; + } + else + { + blkGroupSize = 1; + numBlockTileIteration = + (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + }; + + in_grid_desc_m_k = + MakeSrc2dDescriptor(inLengths_, inStrides_, blkGroupSize, numBlockTileIteration); + + out_grid_desc_m_tuple = generate_tuple( + [&](auto I) { return MakeDst1dDescriptor(outLengths, outStridesArray[I]); }, + Number{}); + + out_grid_desc_m_tuple_2 = generate_tuple( + [&](auto I) { + return MakeDst1dDescriptorForBufferSet(outLengths, outStridesArray[I]); + }, + Number{}); + + gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize * blkGroupSize; + + gridSize_pre = + math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize; + } + + std::array inLengths_; + std::array inStrides_; + + std::array outLengths_; + std::array, NumReduction> outStridesArray_; + + Array alpha_values_; + Array beta_values_; + + const InDataType* in_dev_; + OutDataTypePointerTuple out_dev_buffers_; + + InGridDesc_M_K in_grid_desc_m_k; + OutGridDesc_M_Tuple out_grid_desc_m_tuple; + OutGridDesc_M_Tuple_2 out_grid_desc_m_tuple_2; + + InElementwiseOperationTuple in_elementwise_op_tuple_; + AccElementwiseOperationTuple acc_elementwise_op_tuple_; + + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + int blkGroupSize; + int numBlockTileIteration; + size_t gridSize; + + size_t gridSize_pre; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + using GridwiseMultipleReduce = + GridwiseMultipleReduction_mk_to_m_multiblock; + + const auto kernel_main = + kernel_multiple_reduce_multiblock; + + float avg_time = 0; + + if constexpr(use_multiblock) + { + auto identity_values = generate_tuple( + [&](auto iR) { + using OutDataType = remove_cvref_t; + return ck::reduce::GetIdentityValueForInMemoryDataOperation( + OutMemoryDataOperation); + }, + Number{}); + + const auto kernel_pre = kernel_multiple_buffer_set_value; + + avg_time += launch_and_time_kernel(stream_config, + kernel_pre, + dim3(arg.gridSize_pre), + dim3(BlockSize), + 0, + arg.out_grid_desc_m_tuple_2, + arg.out_dev_buffers_, + identity_values); + }; + + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + arg.in_grid_desc_m_k, + arg.out_grid_desc_m_tuple, + arg.in_elementwise_op_tuple_, + arg.acc_elementwise_op_tuple_, + arg.blkGroupSize, + arg.numBlockTileIteration, + arg.alpha_values_, + arg.in_dev_, + arg.beta_values_, + arg.out_dev_buffers_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if constexpr(use_multiblock) + { + for(size_t i = 0; i < pArg->beta_values_.Size(); i++) + if(pArg->beta_values_[i] != 0.0f) + return (false); + }; 
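+        // Note on the beta check above: in the AtomicAdd (multiblock) path the
+        // outputs are pre-initialized to the reduction identity and then
+        // receive atomically accumulated partial results from blkGroupSize
+        // workgroups, so there is no single point at which
+        // beta * previous_output could be applied exactly once; non-zero beta
+        // values are therefore rejected for this path.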
+ + if constexpr(InSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return (false); + } + else + { + if(pArg->inStrides_[NumInvariantDim - 1] != 1 && InSrcVectorSize != 1) + return (false); + + if(pArg->inLengths_[NumInvariantDim - 1] % InSrcVectorSize != 0) + return (false); + }; + } + else + { + if(pArg->inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1) + return (false); + + if(pArg->inLengths_[Rank - 1] % InSrcVectorSize != 0) + return (false); + }; + // To improve + bool valid = true; + static_for<0, NumReduction, 1>{}([&](auto I) { + if(pArg->outStridesArray_[I.value][NumOutputDim - 1] != 1 && + OutDstVectorSizeSeq::At(I) != 1) + valid = false; + + if(pArg->outLengths_[NumOutputDim - 1] % OutDstVectorSizeSeq::At(I) != 0) + valid = false; + }); + + if(!valid) + return (false); + + if constexpr(use_multiblock) + { + // blkGroupSize of 1 should be handled by Blockwise path using + // InMemoryDataOperationEnum::Set + if(pArg->blkGroupSize == 1) + return (false); + + // This is very strong restriction, but needed to avoid some failure + if(pArg->outLengths_[NumOutputDim - 1] % M_BlockTileSize != 0) + return (false); + } + else + { + // cases with very small reduce_total_length should be handled by ThreadWise kernel + if(pArg->reduce_total_length / KThreadSliceSize < 2) + return (false); + }; + + return (true); + }; + + std::unique_ptr MakeArgumentPointer( + const std::array inLengths, + const std::array inStrides, + const std::array outLengths, + const std::array, NumReduction> outStridesArray, + const std::array reduceDims, + const std::array alphas, + const std::array betas, + const void* in_dev, + const std::array out_dev_buffers, + const InElementwiseOperationTuple in_elementwise_op_tuple, + const AccElementwiseOperationTuple acc_elementwise_op_tuple) override + { + return std::make_unique(inLengths, + inStrides, + outLengths, + outStridesArray, + reduceDims, + alphas, + betas, + in_dev, + out_dev_buffers, + in_elementwise_op_tuple, + acc_elementwise_op_tuple); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << (OutMemoryDataOperation == InMemoryDataOperationEnum::Set? "DeviceMultipleReduceBlockWise<" : "DeviceMultipleReduceMultiBlock<") << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << ","; + str << "OutDstVectorSize"; + static_for<0, OutDstVectorSizeSeq::Size(), 1>{}([&](auto I) {str << "_" << OutDstVectorSizeSeq::At(I); }); + str << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp new file mode 100755 index 000000000000..6d1d5c8e239d --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp @@ -0,0 +1,422 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
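+//
+// Overview (descriptive note; example numbers are for exposition only):
+// DeviceMultipleReduceThreadWise assigns the whole reduce (K) extent of each
+// output element to a single thread: K_BlockTileSize is just KThreadSliceSize
+// (no thread cluster along K), while one workgroup covers
+// M_BlockTileSize = BlockSize * MThreadSliceSize output elements. No atomics
+// or cross-thread reduction are needed, so beta scaling of the existing output
+// can be applied directly. This variant targets small reduce lengths; the
+// blockwise path in device_multiple_reduce_multiblock.hpp rejects cases with
+// reduce_total_length / KThreadSliceSize < 2 so that they are handled here.
+//
+// Grid sizing example: with BlockSize = 256 and MThreadSliceSize = 4,
+// M_BlockTileSize = 1024, so invariant_total_length = 30000 yields
+// gridSize = ceil(30000 / 1024) = 30 workgroups.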
+ +#pragma once + +#include +#include + +#include "ck/utility/sequence.hpp" +#include "ck/utility/reduction_operator.hpp" + +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/device_multiple_reduce.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp" + +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce +{ + static_assert(Rank <= 6, "Bigger Rank size is not supported!"); + + static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert(NumReduction == OutDataTypeTuple::Size() && + NumReduction == InElementwiseOperationTuple::Size() && + NumReduction == AccElementwiseOperationTuple::Size() && + NumReduction == OutDstVectorSizeSeq::Size(), + "All tuple should have the same size as the number of Reductions!"); + + static_assert(sequence_all_of(OutDstVectorSizeSeq{}, + [](auto vectorSize) { + return (MThreadSliceSize % vectorSize == 0); + }), + "The OutDstVectorSize should completely divide the MThreadSliceSize!"); + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + static constexpr index_t NumInputDim = Rank; + static constexpr index_t NumOutputDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize; + + static auto GenerateOutDataTypePointerTuple() + { + return generate_tuple( + [&](auto I) { + using DataType = remove_cvref_t; + + return static_cast(nullptr); + }, + Number{}); + }; + + using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple()); + + static auto MakeSrc2dDescriptor(const std::array& inLengths, + const std::array& inStrides) + { + const auto tupleSrcLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + const auto tupleSrcStrides = + generate_tuple([&](auto I) { return inStrides[I]; }, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumInputDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = generate_tuple( + [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, 
ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = + math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + static auto MakeDst1dDescriptor(const std::array& outLengths, + const std::array& outStrides) + { + const auto tupleDstLengths = + generate_tuple([&](auto I) { return outLengths[I]; }, Number{}); + const auto tupleDstStrides = + generate_tuple([&](auto I) { return outStrides[I]; }, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumOutputDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); + + const auto outPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto out_grid_desc_m_padded = transform_tensor_descriptor( + out_grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, outPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + static auto GenerateOutGrid1dDescTuple() + { + return generate_tuple( + [&](auto I) { + (void)I; + return MakeDst1dDescriptor(std::array{}, + std::array{}); + }, + Number{}); + }; + + using InGridDesc_M_K = decltype(MakeSrc2dDescriptor(std::array{}, + std::array{})); + using OutGridDesc_M_Tuple = decltype(GenerateOutGrid1dDescTuple()); + + struct Argument : public BaseArgument + { + Argument(const std::array& inLengths, + const std::array& inStrides, + const std::array& outLengths, + const std::array, NumReduction>& outStridesArray, + const std::array& reduceDims, + const std::array& alphas, + const std::array& betas, + const void* in_dev, + const std::array& out_dev_buffers, + const InElementwiseOperationTuple in_elementwise_op_tuple, + const AccElementwiseOperationTuple acc_elementwise_op_tuple) + : outLengths_{outLengths}, + outStridesArray_{outStridesArray}, + in_elementwise_op_tuple_{in_elementwise_op_tuple}, + acc_elementwise_op_tuple_{acc_elementwise_op_tuple} + { + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); + + for(size_t i = 0; i < NumReduction; i++) + { + alpha_values_(i) = static_cast(alphas[i]); + beta_values_(i) = static_cast(betas[i]); + }; + + in_dev_ = static_cast(in_dev); + + out_dev_buffers_ = generate_tuple( + [&](auto iR) { + using OutDataTypePointer = + remove_cvref_t; + using OutDataType = remove_cvref_t>; + return static_cast(out_dev_buffers[iR]); + }, + Number{}); + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(inLengths_); + + in_grid_desc_m_k = MakeSrc2dDescriptor(inLengths_, inStrides_); + + out_grid_desc_m_tuple = generate_tuple( + [&](auto I) { return MakeDst1dDescriptor(outLengths, 
outStridesArray[I]); }, + Number{}); + + gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize; + } + + std::array inLengths_; + std::array inStrides_; + + std::array outLengths_; + std::array, NumReduction> outStridesArray_; + + Array alpha_values_; + Array beta_values_; + + const InDataType* in_dev_; + OutDataTypePointerTuple out_dev_buffers_; + + InGridDesc_M_K in_grid_desc_m_k; + OutGridDesc_M_Tuple out_grid_desc_m_tuple; + + InElementwiseOperationTuple in_elementwise_op_tuple_; + AccElementwiseOperationTuple acc_elementwise_op_tuple_; + + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + size_t gridSize; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + using GridwiseMultipleReduce = + GridwiseMultipleReduction_mk_to_m_threadwise; + + const auto kernel_main = + kernel_multiple_reduce_threadwise; + + float avg_time = 0; + + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + arg.in_grid_desc_m_k, + arg.out_grid_desc_m_tuple, + arg.in_elementwise_op_tuple_, + arg.acc_elementwise_op_tuple_, + arg.alpha_values_, + arg.in_dev_, + arg.beta_values_, + arg.out_dev_buffers_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if constexpr(InSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return (false); + } + else + { + if(pArg->inStrides_[NumInvariantDim - 1] != 1 && InSrcVectorSize != 1) + return (false); + + if(pArg->inLengths_[NumInvariantDim - 1] % InSrcVectorSize != 0) + return (false); + }; + } + else + { + if(pArg->inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1) + return (false); + + if(pArg->inLengths_[Rank - 1] % InSrcVectorSize != 0) + return (false); + }; + + // To improve + bool valid = true; + static_for<0, NumReduction, 1>{}([&](auto I) { + if(pArg->outStridesArray_[I.value][NumOutputDim - 1] != 1 && + OutDstVectorSizeSeq::At(I) != 1) + valid = false; + + if(pArg->outLengths_[NumOutputDim - 1] % OutDstVectorSizeSeq::At(I) != 0) + valid = false; + }); + + if(!valid) + return (false); + + return (true); + }; + + std::unique_ptr MakeArgumentPointer( + const std::array inLengths, + const std::array inStrides, + const std::array outLengths, + const std::array, NumReduction> outStridesArray, + const std::array reduceDims, + const std::array alphas, + const std::array betas, + const void* in_dev, + const std::array out_dev_buffers, + const InElementwiseOperationTuple in_elementwise_op_tuple, + const AccElementwiseOperationTuple acc_elementwise_op_tuple) override + { + return std::make_unique(inLengths, + inStrides, + outLengths, + outStridesArray, + reduceDims, + alphas, + betas, + in_dev, + out_dev_buffers, + in_elementwise_op_tuple, + acc_elementwise_op_tuple); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceMultipleReduceThreadwise<" << BlockSize << ","; + str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << 1 << "_S" << KThreadSliceSize << ","; + str << "InSrcVectorDim_" << 
InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << ","; + str << "OutDstVectorSize"; + static_for<0, OutDstVectorSizeSeq::Size(), 1>{}([&](auto I) {str << "_" << OutDstVectorSizeSeq::At(I); }); + str << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp new file mode 100755 index 000000000000..86689af0b702 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_data_impl.hpp @@ -0,0 +1,465 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +// M is Invariant dimension, K is reduced dimension +namespace ck { +namespace tensor_operation { +namespace device { +template +__global__ void +kernel_normalization_bwd_data(const GridDesc_M_K dy_grid_desc_m_k, + const GridDesc_M_K x_grid_desc_m_k, + const GridDesc_M_K gamma_grid_desc_m_k, + const GridDesc_M_K mean_grid_desc_m_k, + const GridDesc_M_K inv_std_grid_desc_m_k, + const GridDesc_M_K dx_grid_desc_m_k, + index_t num_k_block_tile_iteration, + const DYDataType* const __restrict__ p_dy_global, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const MeanInvStdDataType* const __restrict__ p_mean_global, + const MeanInvStdDataType* const __restrict__ p_inv_std_global, + DXDataType* const __restrict__ p_dx_global) +{ + GridwiseNormalizationBwd::Run(dy_grid_desc_m_k, + x_grid_desc_m_k, + gamma_grid_desc_m_k, + mean_grid_desc_m_k, + inv_std_grid_desc_m_k, + dx_grid_desc_m_k, + num_k_block_tile_iteration, + p_dy_global, + p_x_global, + p_gamma_global, + p_mean_global, + p_inv_std_global, + p_dx_global); +}; + +template +struct DeviceNormalizationBwdDataImpl : public DeviceNormalizationBwdData +{ + static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0; + static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0; + static constexpr index_t GammaSrcVectorDim = IsGammaFastestDimReduced ? 1 : 0; + static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0; + static constexpr index_t DXDstVectorDim = IsDxFastestDimReduced ? 
1 : 0; + + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); + + static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize % DYSrcVectorSize == 0) || + (DYSrcVectorDim == 1 && KThreadSliceSize % DYSrcVectorSize == 0)), + "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); + + static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)), + "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + + static_assert( + ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || + (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0), + "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please " + "check!"); + + static_assert(((DXDstVectorDim == 0 && MThreadSliceSize % DXDstVectorSize == 0) || + (DXDstVectorDim == 1 && KThreadSliceSize % DXDstVectorSize == 0)), + "Invalid thread slice sizes and/or dx vector sizes configuration, please check!"); + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + static_assert(!reduceAllDim); + + static auto Make2dDescriptor(const std::vector& lengths, + const std::vector& strides, + int numBlockTileIteration) + { + const auto tupleLengths = make_tuple_from_array(lengths, Number{}); + const auto tupleStrides = make_tuple_from_array(strides, Number{}); + + const auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides); + + const auto grid_desc_m_k = [&]() { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(lengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(lengths, InvariantDims{}); + + return transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{}); + + const auto pad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto pad_K = K_BlockTileSize * numBlockTileIteration - reduceLength; + + auto grid_desc_m_k_padded = + transform_tensor_descriptor(grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, pad_M), + make_right_pad_transform(reduceLength, pad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_m_k_padded; + } + + using GridDesc_M_K = decltype(Make2dDescriptor({1}, {1}, 1)); + + using GridwiseNormalizationBwdDataGeneric = + GridwiseNormalizationBwdData_mk_to_mk; + + using GridwiseNormalizationBwdDataSweepOnce = + GridwiseNormalizationBwdData_mk_to_mk; + + struct 
Argument : public BaseArgument + { + Argument(const std::vector lengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector dxStrides, + const std::vector reduceDims, + const DYDataType* p_dy, + const XDataType* p_x, + const GammaDataType* p_gamma, + const MeanInvStdDataType* p_mean, + const MeanInvStdDataType* p_invStd, + DXDataType* p_dx) + : p_dy_(p_dy), + p_x_(p_x), + p_gamma_(p_gamma), + p_mean_(p_mean), + p_invStd_(p_invStd), + p_dx_(p_dx) + { + lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); + dyStrides_ = shuffle_tensor_dimensions(dyStrides, reduceDims); + xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); + gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims); + meanStrides_ = shuffle_tensor_dimensions(meanStrides, reduceDims); + invStdStrides_ = + shuffle_tensor_dimensions(invStdStrides, reduceDims); + dxStrides_ = shuffle_tensor_dimensions(dxStrides, reduceDims); + + std::tie(MRaw_, KRaw_) = get_2d_lengths(lengths_); + + numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize); + + gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize); + + dy_grid_desc_m_k_ = Make2dDescriptor(lengths_, dyStrides_, numBlockTileIteration_); + x_grid_desc_m_k_ = Make2dDescriptor(lengths_, xStrides_, numBlockTileIteration_); + gamma_grid_desc_m_k_ = + Make2dDescriptor(lengths_, gammaStrides_, numBlockTileIteration_); + mean_grid_desc_m_k_ = Make2dDescriptor(lengths_, meanStrides_, numBlockTileIteration_); + inv_std_grid_desc_m_k_ = + Make2dDescriptor(lengths_, invStdStrides_, numBlockTileIteration_); + dx_grid_desc_m_k_ = Make2dDescriptor(lengths_, dxStrides_, numBlockTileIteration_); + + isSweeponce_ = dy_grid_desc_m_k_.GetLength(Number<1>{}) <= K_BlockTileSize; + } + + const DYDataType* p_dy_; + const XDataType* p_x_; + const GammaDataType* p_gamma_; + const MeanInvStdDataType* p_mean_; + const MeanInvStdDataType* p_invStd_; + DXDataType* p_dx_; + + std::vector lengths_; + std::vector dyStrides_; + std::vector xStrides_; + std::vector gammaStrides_; + std::vector meanStrides_; + std::vector invStdStrides_; + std::vector dxStrides_; + + int numBlockTileIteration_; + size_t gridSize_; + + // tensor descriptor + GridDesc_M_K dy_grid_desc_m_k_; + GridDesc_M_K x_grid_desc_m_k_; + GridDesc_M_K gamma_grid_desc_m_k_; + GridDesc_M_K mean_grid_desc_m_k_; + GridDesc_M_K inv_std_grid_desc_m_k_; + GridDesc_M_K dx_grid_desc_m_k_; + + bool isSweeponce_; + index_t MRaw_; // Invariant length + index_t KRaw_; // reduce length + }; + + struct Invoker : public BaseInvoker + { + auto KernelSelector(bool isSweepOnce) + { + return isSweepOnce + ? 
kernel_normalization_bwd_data + : kernel_normalization_bwd_data; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto kernel_main = KernelSelector(arg.isSweeponce_); + + return launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.dy_grid_desc_m_k_, + arg.x_grid_desc_m_k_, + arg.gamma_grid_desc_m_k_, + arg.mean_grid_desc_m_k_, + arg.inv_std_grid_desc_m_k_, + arg.dx_grid_desc_m_k_, + arg.numBlockTileIteration_, + arg.p_dy_, + arg.p_x_, + arg.p_gamma_, + arg.p_mean_, + arg.p_invStd_, + arg.p_dx_); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + template + bool IsVectorDimSizeValid(const std::vector& lengths, + const std::vector& strides) + { + if constexpr(SrcVectorSize == 1) + return true; + + // Fastest dimension is not reduced + if constexpr(SrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + return false; + + if(strides[NumInvariantDim - 1] != 1) + return false; + + if(lengths[NumInvariantDim - 1] % SrcVectorSize != 0) + return false; + } + else // Fastest dimension is reduced + { + if(strides[Rank - 1] != 1) + return false; + + if(lengths[Rank - 1] % SrcVectorSize != 0) + return false; + }; + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + bool pass = true; + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->dyStrides_); + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->xStrides_); + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->gammaStrides_); + pass &= IsVectorDimSizeValid( + p_arg_->lengths_, p_arg_->meanStrides_); + pass &= IsVectorDimSizeValid( + p_arg_->lengths_, p_arg_->invStdStrides_); + + pass &= IsVectorDimSizeValid(p_arg_->lengths_, + p_arg_->dxStrides_); + return pass; + } + + std::unique_ptr MakeArgumentPointer(const std::vector lengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector dxStrides, + const std::vector reduceDims, + const void* p_dy, + const void* p_x, + const void* p_gamma, + const void* p_mean, + const void* p_invStd, + void* p_dx) override + { + if(lengths.size() != Rank || dyStrides.size() != Rank || xStrides.size() != Rank || + gammaStrides.size() != Rank || meanStrides.size() != Rank || + invStdStrides.size() != Rank || dxStrides.size() != Rank) + throw std::runtime_error("dimension is incorrect"); + + return std::make_unique(lengths, + dyStrides, + xStrides, + gammaStrides, + meanStrides, + invStdStrides, + dxStrides, + reduceDims, + static_cast(p_dy), + static_cast(p_x), + static_cast(p_gamma), + static_cast(p_mean), + static_cast(p_invStd), + static_cast(p_dx)); + } + + virtual std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceNormalizationBwdDataImpl<" << BlockSize << ","; + str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; + str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; + str << "DYSrcVectorSize" << DYSrcVectorSize << "_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_MeanRstd" << MeanInvStdSrcVectorSize << "_Dx" << DXDstVectorSize; + str 
<< ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp new file mode 100755 index 000000000000..c35652bfe860 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp @@ -0,0 +1,478 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_operation/gpu/device/device_normalization_bwd_gamma_beta.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +// M is Invariant dimension, K is reduced dimension +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +kernel_normalization_bwd_gamma_beta(const GridDesc_M_K dy_grid_desc_m_k, + const GridDesc_M_K x_grid_desc_m_k, + const GridDesc_M_K mean_grid_desc_m_k, + const GridDesc_M_K inv_std_grid_desc_m_k, + const GridDesc_M dgamma_grid_desc_m, + const GridDesc_M dbeta_grid_desc_m, + index_t num_k_block_tile_iteration, + const DYDataType* const __restrict__ p_dy_global, + const XDataType* const __restrict__ p_x_global, + const MeanInvStdDataType* const __restrict__ p_mean_global, + const MeanInvStdDataType* const __restrict__ p_inv_std_global, + DGammaDataType* const __restrict__ p_dgamma_global, + DBetaDataType* const __restrict__ p_dbeta_global) +{ + GridwiseReduction::Run(dy_grid_desc_m_k, + x_grid_desc_m_k, + mean_grid_desc_m_k, + inv_std_grid_desc_m_k, + dgamma_grid_desc_m, + dbeta_grid_desc_m, + num_k_block_tile_iteration, + p_dy_global, + p_x_global, + p_mean_global, + p_inv_std_global, + p_dgamma_global, + p_dbeta_global); +}; + +template +struct DeviceNormalizationBwdGammaBetaImpl + : public DeviceNormalizationBwdGammaBeta +{ + static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0; + static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0; + static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 
1 : 0; + + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); + + static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize % DYSrcVectorSize == 0) || + (DYSrcVectorDim == 1 && KThreadSliceSize % DYSrcVectorSize == 0)), + "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); + + static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)), + "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + + static_assert( + (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0), + "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please " + "check!"); + + static_assert( + ((MThreadSliceSize % DGammaDstVectorSize == 0) || + (MThreadSliceSize % DBetaDstVectorSize == 0)), + "Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please " + "check!"); + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + static_assert(!reduceAllDim); + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, + int numBlockTileIteration) + { + const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); + const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); + + return transform_tensor_descriptor(inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = K_BlockTileSize * numBlockTileIteration - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return in_grid_desc_m_k_padded; + } + + static auto MakeDst1dDescriptor(const std::vector& outLengths, + const std::vector& outStrides) + { + const auto tupleDstLengths = + generate_tuple([&](auto I) { return outLengths[I]; }, Number{}); + const auto tupleDstStrides = + generate_tuple([&](auto I) { return outStrides[I]; }, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + 
make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); + + const auto outPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto out_grid_desc_m_padded = transform_tensor_descriptor( + out_grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, outPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1)); + using GridDesc_M = decltype(MakeDst1dDescriptor({1}, {1})); + + using GridwiseNormalizationBwdGammaBeta = + GridwiseNormalizationBwdGammaBeta_mk_to_k; + + struct Argument : public BaseArgument + { + Argument(const std::vector inLengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector outLengths, + const std::vector dgammaStrides, + const std::vector dbetaStrides, + const std::vector reduceDims, + const DYDataType* p_dy, + const XDataType* p_x, + const MeanInvStdDataType* p_mean, + const MeanInvStdDataType* p_invStd, + DGammaDataType* p_dgamma, + DBetaDataType* p_dbeta) + : p_dy_(p_dy), + p_x_(p_x), + p_mean_(p_mean), + p_invStd_(p_invStd), + p_dgamma_(p_dgamma), + p_dbeta_(p_dbeta), + outLengths_{outLengths}, + dgammaStrides_{dgammaStrides}, + dbetaStrides_{dbetaStrides} + { + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + dyStrides_ = shuffle_tensor_dimensions(dyStrides, reduceDims); + xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); + meanStrides_ = shuffle_tensor_dimensions(meanStrides, reduceDims); + invStdStrides_ = + shuffle_tensor_dimensions(invStdStrides, reduceDims); + + std::tie(MRaw_, KRaw_) = get_2d_lengths(inLengths_); + + numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize); + + gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize); + + dy_grid_desc_m_k_ = MakeSrc2dDescriptor(inLengths_, dyStrides_, numBlockTileIteration_); + x_grid_desc_m_k_ = MakeSrc2dDescriptor(inLengths_, xStrides_, numBlockTileIteration_); + mean_grid_desc_m_k_ = + MakeSrc2dDescriptor(inLengths_, meanStrides_, numBlockTileIteration_); + inv_std_grid_desc_m_k_ = + MakeSrc2dDescriptor(inLengths_, invStdStrides_, numBlockTileIteration_); + + dgamma_grid_desc_m_ = MakeDst1dDescriptor(outLengths_, dgammaStrides_); + dbeta_grid_desc_m_ = MakeDst1dDescriptor(outLengths_, dbetaStrides_); + } + + const DYDataType* p_dy_; + const XDataType* p_x_; + const MeanInvStdDataType* p_mean_; + const MeanInvStdDataType* p_invStd_; + DGammaDataType* p_dgamma_; + DBetaDataType* p_dbeta_; + + std::vector inLengths_; + std::vector dyStrides_; + std::vector xStrides_; + std::vector meanStrides_; + std::vector invStdStrides_; + std::vector outLengths_; + std::vector dgammaStrides_; + std::vector dbetaStrides_; + + int numBlockTileIteration_; + size_t gridSize_; + + // Source descriptor + GridDesc_M_K dy_grid_desc_m_k_; + GridDesc_M_K x_grid_desc_m_k_; + GridDesc_M_K mean_grid_desc_m_k_; + GridDesc_M_K inv_std_grid_desc_m_k_; + + // Destination descriptor + GridDesc_M dgamma_grid_desc_m_; + GridDesc_M dbeta_grid_desc_m_; + + index_t MRaw_; // Invariant length + index_t KRaw_; // reduce length + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { 
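+            // One kernel launch covers the whole problem: gridSize_ =
+            // ceil(MRaw_ / M_BlockTileSize) workgroups of BlockSize threads,
+            // each iterating numBlockTileIteration_ tiles along the reduced
+            // (K) dimension while producing its tile of dgamma and dbeta; the
+            // return value is the time reported by launch_and_time_kernel.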
+ const auto kernel_main = + kernel_normalization_bwd_gamma_beta; + + return launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.dy_grid_desc_m_k_, + arg.x_grid_desc_m_k_, + arg.mean_grid_desc_m_k_, + arg.inv_std_grid_desc_m_k_, + arg.dgamma_grid_desc_m_, + arg.dbeta_grid_desc_m_, + arg.numBlockTileIteration_, + arg.p_dy_, + arg.p_x_, + arg.p_mean_, + arg.p_invStd_, + arg.p_dgamma_, + arg.p_dbeta_); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + template + bool IsSrcVectorDimSizeValid(const std::vector& lengths, + const std::vector& strides) + { + if constexpr(SrcVectorSize == 1) + return true; + + // Fastest dimension is not reduced + if constexpr(SrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + return false; + + if(strides[NumInvariantDim - 1] != 1) + return false; + + if(lengths[NumInvariantDim - 1] % SrcVectorSize != 0) + return false; + } + else // Fastest dimension is reduced + { + if(strides[Rank - 1] != 1) + return false; + + if(lengths[Rank - 1] % SrcVectorSize != 0) + return false; + }; + + return true; + } + + template + bool IsDstVectorSizeValid(const std::vector& lengths, + const std::vector& strides) + { + if constexpr(DstVectorSize == 1) + return true; + + if(strides[NumInvariantDim - 1] != 1) + return false; + + if(lengths[NumInvariantDim - 1] % DstVectorSize != 0) + return false; + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + bool pass = true; + pass &= IsSrcVectorDimSizeValid(p_arg_->inLengths_, + p_arg_->dyStrides_); + pass &= IsSrcVectorDimSizeValid(p_arg_->inLengths_, + p_arg_->xStrides_); + pass &= IsSrcVectorDimSizeValid( + p_arg_->inLengths_, p_arg_->meanStrides_); + pass &= IsSrcVectorDimSizeValid( + p_arg_->inLengths_, p_arg_->invStdStrides_); + + pass &= + IsDstVectorSizeValid(p_arg_->outLengths_, p_arg_->dgammaStrides_); + pass &= + IsDstVectorSizeValid(p_arg_->outLengths_, p_arg_->dbetaStrides_); + + return pass; + } + + std::unique_ptr MakeArgumentPointer(const std::vector inLengths, + const std::vector dyStrides, + const std::vector xStrides, + const std::vector meanStrides, + const std::vector invStdStrides, + const std::vector outLengths, + const std::vector dgammaStrides, + const std::vector dbetaStrides, + const std::vector reduceDims, + const void* p_dy, + const void* p_x, + const void* p_mean, + const void* p_invStd, + void* p_dgamma, + void* p_dbeta) override + { + if(inLengths.size() != Rank || dyStrides.size() != Rank || xStrides.size() != Rank || + meanStrides.size() != Rank || invStdStrides.size() != Rank) + throw std::runtime_error("dimension is incorrect"); + + if(outLengths.size() != NumInvariantDim || dgammaStrides.size() != NumInvariantDim || + dbetaStrides.size() != NumInvariantDim) + throw std::runtime_error("dimension is incorrect"); + + return std::make_unique(inLengths, + dyStrides, + xStrides, + meanStrides, + invStdStrides, + outLengths, + dgammaStrides, + dbetaStrides, + reduceDims, + static_cast(p_dy), + static_cast(p_x), + static_cast(p_mean), + static_cast(p_invStd), + static_cast(p_dgamma), + static_cast(p_dbeta)); + } + + virtual std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << 
"DeviceNormalizationBwdGammaBetaImpl<" << BlockSize << ","; + str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; + str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; + str << "VectorSize_DY" << DYSrcVectorSize << "_X" << XSrcVectorSize ; + str << "_DGamma" << DGammaDstVectorSize << "_DBeta" << DBetaDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp new file mode 100755 index 000000000000..fa7b9cbda023 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp @@ -0,0 +1,475 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Y = Normalization(X, Beta, Gamma) +// M: Invariant length +// K: Reduce length (Calculate mean and variance along K dimension) +// eg. Length = [N, C, H, W], reduce dim = [C, H, W] +// Then, M = N, K = C * H * W +template +struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd +{ + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); + static_assert( + ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || + (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + ((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) || + (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); + + static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0, + "Invalid thread slice sizes and/or save mean and inverse std vector sizes " + "configuration, please check!"); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + static_assert(!reduceAllDim); // TODO + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, + int numBlockTileIteration) + { + static constexpr index_t numSrcDim = Rank; + + const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); + const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto 
one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = K_BlockTileSize * numBlockTileIteration - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + static auto MakeSaveMeanInvStdDescriptor_M(const std::vector& lengths, + const std::vector& strides) + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + + const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{}); + const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{}); + + const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto grid_desc_m = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(InvariantDims{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = grid_desc_m.GetLength(Number<0>{}); + const auto pad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_padded = transform_tensor_descriptor( + grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, pad_M)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return grid_desc_m_padded; + } + + using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1)); + using GridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1})); + + struct Argument : public BaseArgument + { + Argument(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, + const std::vector reduceDims, + YElementwiseOperation y_elementwise_op, + double epsilon, + const XDataType* p_x, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + YDataType* p_y, + SaveMeanInvStdDataType* p_saveMean, + SaveMeanInvStdDataType* p_saveInvStd) + : p_x_(p_x), + p_gamma_(p_gamma), + p_beta_(p_beta), + p_y_(p_y), + 
p_saveMean_(p_saveMean), + p_saveInvStd_(p_saveInvStd), + y_elementwise_op_(y_elementwise_op) + { + epsilon_ = static_cast(epsilon); + + Lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); + xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); + yStrides_ = shuffle_tensor_dimensions(yStrides, reduceDims); + gammaStrides_ = shuffle_tensor_dimensions(gammaStrides, reduceDims); + betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims); + saveMeanStrides_ = saveMeanStrides; + saveInvStdStrides_ = saveInvStdStrides; + + std::tie(MRaw_, KRaw_) = get_2d_lengths(Lengths_); + + numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize); + + gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize); + + x_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, xStrides_, numBlockTileIteration_); + gamma_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, gammaStrides_, numBlockTileIteration_); + beta_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, betaStrides_, numBlockTileIteration_); + y_grid_desc_m_k_ = MakeSrc2dDescriptor(Lengths_, yStrides_, numBlockTileIteration_); + save_mean_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveMeanStrides); + save_inv_std_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveInvStdStrides); + + isSweeponce_ = + x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length_ = 1; + else + invariant_lowest_length_ = Lengths_[NumInvariantDim - 1]; + } + + ComputeDataType epsilon_; + + const XDataType* p_x_; + const GammaDataType* p_gamma_; + const BetaDataType* p_beta_; + YDataType* p_y_; + SaveMeanInvStdDataType* p_saveMean_; + SaveMeanInvStdDataType* p_saveInvStd_; + + std::vector Lengths_; + std::vector xStrides_; + std::vector gammaStrides_; + std::vector betaStrides_; + std::vector yStrides_; + std::vector saveMeanStrides_; + std::vector saveInvStdStrides_; + + YElementwiseOperation y_elementwise_op_; + + int numBlockTileIteration_; + size_t gridSize_; + + GridDesc_M_K x_grid_desc_m_k_; + GridDesc_M_K gamma_grid_desc_m_k_; + GridDesc_M_K beta_grid_desc_m_k_; + GridDesc_M_K y_grid_desc_m_k_; + GridDesc_M save_mean_grid_desc_m_; + GridDesc_M save_inv_std_grid_desc_m_; + bool isSweeponce_; + + index_t MRaw_; // Invariant length + index_t KRaw_; // reduce length + + index_t invariant_lowest_length_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + auto kernel_main = NormalizationKernelSelector(arg.isSweeponce_); + + float avg_time = 0; + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.gamma_grid_desc_m_k_, + arg.beta_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + arg.save_mean_grid_desc_m_, + arg.save_inv_std_grid_desc_m_, + arg.numBlockTileIteration_, + arg.epsilon_, + arg.p_x_, + arg.p_gamma_, + arg.p_beta_, + arg.p_y_, + arg.p_saveMean_, + arg.p_saveInvStd_, + arg.y_elementwise_op_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + if constexpr(XYSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return false; + } + else + { + if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) 
+ return false; + + if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0) + return false; + + if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0) + return false; + }; + } + else + { + if(p_arg_->xStrides_[Rank - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0) + return false; + + if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0) + { + return false; + } + }; + + // if fastest dim is not reduced + if constexpr(GammaSrcVectorDim == 0) + { + if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) + return (false); + } + else // if fastest dim is reduced + { + if(p_arg_->gammaStrides_[Rank - 1] != 1) + return (false); + + if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) + return (false); + } + + // if fastest dim is not reduced + if constexpr(BetaSrcVectorDim == 0) + { + if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0) + return (false); + } + else // if fastest dim is reduced + { + if(p_arg_->betaStrides_[Rank - 1] != 1) + return (false); + + if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0) + return (false); + } + + if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0) + return false; + + return true; + }; + + std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, + const std::vector reduceDims, + double epsilon, + const void* p_x, + const void* p_gamma, + const void* p_beta, + void* p_y, + void* p_saveMean, + void* p_saveInvStd, + YElementwiseOperation y_elementwise_op) override + { + if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank || + betaStrides.size() != Rank || yStrides.size() != Rank || + saveMeanStrides.size() != NumInvariantDim || saveInvStdStrides.size() != NumInvariantDim) + throw std::runtime_error("dimension is incorrect"); + + return std::make_unique(lengths, + xStrides, + gammaStrides, + betaStrides, + yStrides, + saveMeanStrides, + saveInvStdStrides, + reduceDims, + y_elementwise_op, + epsilon, + static_cast(p_x), + static_cast(p_gamma), + static_cast(p_beta), + static_cast(p_y), + static_cast(p_saveMean), + static_cast(p_saveInvStd)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceNormalizationFwdImpl<" << BlockSize << ","; + str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; + str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; + str << "XYSrcVectorDim_" << XYSrcVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp new file mode 100755 index 000000000000..7fe6502bab45 --- /dev/null +++ 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp @@ -0,0 +1,750 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/device_normalization_fwd.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +template +__global__ void +kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k, + const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, + index_t num_k_block_tile_iteration, + const XDataType* const __restrict__ p_x_global, + WorkspaceMeanVarDataType* const __restrict__ p_welford_mean, + WorkspaceMeanVarDataType* const __restrict__ p_welford_variance, + int32_t* const __restrict__ p_welford_count) +{ + GridwiseWelford::Run(x_grid_desc_m_k, + mean_var_grid_desc_m_kblock, + num_k_block_tile_iteration, + p_x_global, + p_welford_mean, + p_welford_variance, + p_welford_count); +}; + +template +__global__ void +kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, + const CountGridDesc_M_KBlock count_grid_desc_m_kblock, + const XYGammaBetaGridDesc_M_K x_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K y_grid_desc_m_k, + const SaveMeanInvStdGridDesc_M save_mean_grid_desc_m, + const SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m, + index_t num_k_mean_var_count_iteration, + index_t num_k_block_tile_iteration, + index_t k_grid_size, + ComputeDataType epsilon, + const WorkspaceMeanVarDataType* const p_mean_global, + const WorkspaceMeanVarDataType* const p_variance_global, + const int32_t* const p_welford_count_global, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, + const YElementwiseOperation y_elementwise_op) +{ + GridwiseWelfordNormalization::Run(mean_var_grid_desc_m_kblock, + count_grid_desc_m_kblock, + x_grid_desc_m_k, + gamma_grid_desc_m_k, + beta_grid_desc_m_k, + y_grid_desc_m_k, + save_mean_grid_desc_m, + save_inv_std_grid_desc_m, + num_k_mean_var_count_iteration, + num_k_block_tile_iteration, + k_grid_size, + epsilon, + p_mean_global, + p_variance_global, + p_welford_count_global, + p_x_global, + p_gamma_global, + p_beta_global, + p_y_global, + p_save_mean_global, + p_save_inv_std_global, + y_elementwise_op); +}; +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Y = Normalization(X, Beta, Gamma) +// M: Invariant length +// K: Reduce length (Calculate mean and variance along K dimension) +// eg. 
Length = [N, C, H, W], reduce dim = [C, H, W] +// Then, M = N, K = C * H * W +template +struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd +{ + using WorkspaceMeanVarDataType = SaveMeanInvStdDataType; + + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize); + static_assert( + ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) || + (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + ((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) || + (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)), + "Invalid thread slice sizes and/or beta vector sizes configuration, please check!"); + + static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0, + "Invalid thread slice sizes and/or save mean and inverse std vector sizes " + "configuration, please check!"); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + static_assert(!reduceAllDim); // TODO + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, + int kBlockSize, + int numBlockTileIteration) + { + static constexpr index_t numSrcDim = Rank; + + const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number{}); + const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = + make_tuple_from_array_and_index_seq(inLengths, ReduceDims{}); + const auto invariantDimLengths = + make_tuple_from_array_and_index_seq(inLengths, InvariantDims{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = reduceSizePerBlock * kBlockSize - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + 
make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + template + static auto MakeWorkspaceMeanVarDescriptor_M_K(index_t M, index_t K) + { + const auto grid_desc_m_k = + make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1)); + return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{}); + } + + template + static auto MakeWorkspaceCountDescriptor_M_K(index_t M, index_t K) + { + const auto grid_desc_m_k = + make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I0, I1)); + return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{}); + } + + static auto MakeSaveMeanInvStdDescriptor_M(const std::vector& lengths, + const std::vector& strides) + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + + const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{}); + const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{}); + + const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto grid_desc_m = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(InvariantDims{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = grid_desc_m.GetLength(Number<0>{}); + const auto pad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto grid_desc_m_padded = transform_tensor_descriptor( + grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, pad_M)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return grid_desc_m_padded; + } + + using SrcGridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); + using Kernel1MeanVarGridDesc_M_KBlock = + decltype(MakeWorkspaceMeanVarDescriptor_M_K, 1, 1>(1, 1)); + + using Kernel2MeanVarGridDesc_M_KBlock = + decltype(MakeWorkspaceMeanVarDescriptor_M_K, 1, 1>(1, 1)); + + using Kernel2CountGridDesc_M_KBlock = + decltype(MakeWorkspaceCountDescriptor_M_K, 1, 1>(1, 1)); + + using SaveMeanInvStdGridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1})); + + using GridwiseWelford = GridwiseNormalizationSplitK1st; + + using GridwiseWelfordNormalization = + GridwiseNormalizationSplitK2nd; + + struct Argument : public BaseArgument + { + Argument(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, + const std::vector reduceDims, + YElementwiseOperation y_elementwise_op, + double epsilon, + const XDataType* p_x, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + YDataType* p_y, + SaveMeanInvStdDataType* p_saveMean, + SaveMeanInvStdDataType* p_saveInvStd) + : p_x_(p_x), + p_gamma_(p_gamma), + p_beta_(p_beta), + p_y_(p_y), + p_saveMean_(p_saveMean), + p_saveInvStd_(p_saveInvStd), + p_workspace_mean_{nullptr}, + p_workspace_var_{nullptr}, + p_workspace_count_{nullptr}, + y_elementwise_op_(y_elementwise_op) + { + epsilon_ = static_cast(epsilon); + + Lengths_ = shuffle_tensor_dimensions(lengths, reduceDims); + xStrides_ = shuffle_tensor_dimensions(xStrides, reduceDims); + yStrides_ = shuffle_tensor_dimensions(yStrides, reduceDims); + gammaStrides_ = 
shuffle_tensor_dimensions(gammaStrides, reduceDims); + betaStrides_ = shuffle_tensor_dimensions(betaStrides, reduceDims); + saveMeanStrides_ = saveMeanStrides; + saveInvStdStrides_ = saveInvStdStrides; + + std::tie(MRaw_, KRaw_) = get_2d_lengths(Lengths_); + + numBlockTileIteration_ = 1; + while(true) + { + int testKGridSize = + math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_); + + // we want the kGridSize_ be not more than 128 + if(testKGridSize <= 128) + break; + + ++numBlockTileIteration_; + }; + + kGridSize_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize * numBlockTileIteration_); + gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize) * kGridSize_; + + // We do not use vector load for mean, var and count + static constexpr index_t K_MeanVarCountBlockTileSize = KThreadClusterSize; + + numMeanVarCountIteration_ = + math::integer_divide_ceil(kGridSize_, K_MeanVarCountBlockTileSize); + + x_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, xStrides_, kGridSize_, numBlockTileIteration_); + gamma_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, gammaStrides_, kGridSize_, numBlockTileIteration_); + beta_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, betaStrides_, kGridSize_, numBlockTileIteration_); + y_grid_desc_m_k_ = + MakeSrc2dDescriptor(Lengths_, yStrides_, kGridSize_, numBlockTileIteration_); + + save_mean_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveMeanStrides); + save_inv_std_grid_desc_m_ = MakeSaveMeanInvStdDescriptor_M(Lengths_, saveInvStdStrides); + + // We don't need to pad in K dimension for Welford1. Set KPerTile 1. + kernel1_mean_var_grid_desc_m_kblock_ = + MakeWorkspaceMeanVarDescriptor_M_K, M_BlockTileSize, 1>( + MRaw_, kGridSize_); + + kernel2_mean_var_grid_desc_m_kblock_ = + MakeWorkspaceMeanVarDescriptor_M_K, + M_BlockTileSize, + K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_); + + kernel2_count_grid_desc_m_kblock_ = + MakeWorkspaceCountDescriptor_M_K, + M_BlockTileSize, + K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_); + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length_ = 1; + else + invariant_lowest_length_ = Lengths_[NumInvariantDim - 1]; + } + + ComputeDataType epsilon_; + + const XDataType* p_x_; + const GammaDataType* p_gamma_; + const BetaDataType* p_beta_; + YDataType* p_y_; + SaveMeanInvStdDataType* p_saveMean_; + SaveMeanInvStdDataType* p_saveInvStd_; + void* p_workspace_mean_; + void* p_workspace_var_; + void* p_workspace_count_; + + std::vector Lengths_; + std::vector xStrides_; + std::vector gammaStrides_; + std::vector betaStrides_; + std::vector yStrides_; + std::vector saveMeanStrides_; + std::vector saveInvStdStrides_; + + YElementwiseOperation y_elementwise_op_; + + int kGridSize_; + int numMeanVarCountIteration_; + int numBlockTileIteration_; + size_t gridSize_; + + SrcGridDesc_M_K x_grid_desc_m_k_; + SrcGridDesc_M_K gamma_grid_desc_m_k_; + SrcGridDesc_M_K beta_grid_desc_m_k_; + SrcGridDesc_M_K y_grid_desc_m_k_; + SaveMeanInvStdGridDesc_M save_mean_grid_desc_m_; + SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m_; + + Kernel1MeanVarGridDesc_M_KBlock kernel1_mean_var_grid_desc_m_kblock_; + Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_; + Kernel2CountGridDesc_M_KBlock kernel2_count_grid_desc_m_kblock_; + + index_t MRaw_; // Invariant length + index_t KRaw_; // reduce length + + index_t invariant_lowest_length_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + 
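// Split-K flow (summary of the two launches below): the first kernel runs a
+ // block-wise Welford pass and writes partial mean / variance / count results into
+ // the workspace buffers; the second kernel merges those partials, normalizes X with
+ // gamma and beta, and optionally stores the saved mean and inverse standard
+ // deviation. The returned time is the sum of both launches. +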
if(arg.p_workspace_mean_ == nullptr || arg.p_workspace_var_ == nullptr || + arg.p_workspace_count_ == nullptr) + throw std::runtime_error("wrong! WorkSpace pointer has not been set"); + + auto kernel1 = kernel_normalizationSplitK1st; + + auto kernel2 = kernel_normalizationSplitK2nd; + + float avg_time = 0; + avg_time += launch_and_time_kernel( + stream_config, + kernel1, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.x_grid_desc_m_k_, + arg.kernel1_mean_var_grid_desc_m_kblock_, + arg.numBlockTileIteration_, + arg.p_x_, + static_cast(arg.p_workspace_mean_), + static_cast(arg.p_workspace_var_), + static_cast(arg.p_workspace_count_)); + + avg_time += launch_and_time_kernel( + stream_config, + kernel2, + dim3(arg.gridSize_), + dim3(BlockSize), + 0, + arg.kernel2_mean_var_grid_desc_m_kblock_, + arg.kernel2_count_grid_desc_m_kblock_, + arg.x_grid_desc_m_k_, + arg.gamma_grid_desc_m_k_, + arg.beta_grid_desc_m_k_, + arg.y_grid_desc_m_k_, + arg.save_mean_grid_desc_m_, + arg.save_inv_std_grid_desc_m_, + arg.numMeanVarCountIteration_, + arg.numBlockTileIteration_, + arg.kGridSize_, + arg.epsilon_, + static_cast(arg.p_workspace_mean_), + static_cast(arg.p_workspace_var_), + static_cast(arg.p_workspace_count_), + arg.p_x_, + arg.p_gamma_, + arg.p_beta_, + arg.p_y_, + arg.p_saveMean_, + arg.p_saveInvStd_, + arg.y_elementwise_op_); + + return avg_time; + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + size_t GetWorkSpaceSize(const BaseArgument* pArg) const override + { + const Argument* pArg_ = dynamic_cast(pArg); + + size_t workspace_size = 0; + + int welford_size = pArg_->MRaw_ * pArg_->kGridSize_; + + // workspace for welford intermediate mean + workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64; + + // workspace for welford intermediate variance + workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64; + + // workspace for welford intermediate count + workspace_size += pArg_->kGridSize_ * sizeof(int32_t) + 64; + + return (workspace_size); + }; + + void SetWorkSpacePointer(BaseArgument* pArg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + Argument* pArg_ = dynamic_cast(pArg); + + pArg_->p_workspace_ = p_workspace; + + int welford_size = pArg_->MRaw_ * pArg_->kGridSize_; + + // setup buffer used for intermediate welford mean + pArg_->p_workspace_mean_ = static_cast(pArg_->p_workspace_); + + index_t mean_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType); + mean_space_sz = math::integer_least_multiple(mean_space_sz, 64); + + // setup buffer used for intermediate welford varirance + pArg_->p_workspace_var_ = reinterpret_cast(pArg_->p_workspace_mean_) + mean_space_sz; + + index_t variance_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType); + variance_space_sz = math::integer_least_multiple(variance_space_sz, 64); + + // setup buffer used for intermediate welford count + pArg_->p_workspace_count_ = + reinterpret_cast(pArg_->p_workspace_var_) + variance_space_sz; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* p_arg_ = dynamic_cast(p_arg); + + if constexpr(XYVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return false; + } + else + { + if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) + return false; + + if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0) + return false; + + if(p_arg_->invariant_lowest_length_ % YDstVectorSize 
!= 0) + return false; + }; + } + else + { + if(p_arg_->xStrides_[Rank - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0) + return false; + + if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0) + return false; + }; + + // if fastest dim is not reduced + if constexpr(GammaSrcVectorDim == 0) + { + if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) + return false; + } + else // if fastest dim is reduced + { + if(p_arg_->gammaStrides_[Rank - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0) + return false; + } + + // if fastest dim is not reduced + if constexpr(BetaSrcVectorDim == 0) + { + if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1) + return false; + + if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0) + return false; + } + else // if fastest dim is reduced + { + if(p_arg_->betaStrides_[Rank - 1] != 1) + return false; + + if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0) + return false; + } + + if(p_arg_->kGridSize_ <= 1) + return false; + + if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0) + return false; + + return true; + }; + + std::unique_ptr + MakeArgumentPointer(const std::vector lengths, + const std::vector xStrides, + const std::vector gammaStrides, + const std::vector betaStrides, + const std::vector yStrides, + const std::vector saveMeanStrides, + const std::vector saveInvStdStrides, + const std::vector reduceDims, + double epsilon, + const void* p_x, + const void* p_gamma, + const void* p_beta, + void* p_y, + void* p_saveMean, + void* p_saveInvStd, + YElementwiseOperation y_elementwise_op) override + { + if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank || + betaStrides.size() != Rank || yStrides.size() != Rank || + saveMeanStrides.size() != NumInvariantDim || saveInvStdStrides.size() != NumInvariantDim) + throw std::runtime_error("dimension is incorrect"); + + return std::make_unique(lengths, + xStrides, + gammaStrides, + betaStrides, + yStrides, + saveMeanStrides, + saveInvStdStrides, + reduceDims, + y_elementwise_op, + epsilon, + static_cast(p_x), + static_cast(p_gamma), + static_cast(p_beta), + static_cast(p_y), + static_cast(p_saveMean), + static_cast(p_saveInvStd)); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceNormalizationFwdSplitKImpl<" << BlockSize << ","; + str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ","; + str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ","; + str << "XYSrcVectorDim_" << XYVectorDim << ","; + str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp new file mode 100755 index 000000000000..17dab0833216 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp @@ -0,0 +1,282 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro 
Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/math.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/device_permute.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// Swap last 2 dimensions +// input shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]] +// ^^^^^^^^^^^ +// output shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-1], d[NumDim-2]] +// ^^^^^^^^^^^ +template +struct DevicePermuteImpl : DevicePermute +{ + using BaseType = DevicePermute; + using typename BaseType::Lengths; + using typename BaseType::Strides; + + static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor"); + static_assert((NumDim - 2) <= SrcVectorDim && SrcVectorDim < NumDim); + static_assert((NumDim - 2) <= DstVectorDim && DstVectorDim < NumDim); + static_assert(SrcVectorDim != DstVectorDim); + + template + static auto ConvertArrayToTuple(const std::array& array) + { + static_assert(1 <= N && N <= NumDim); + + return generate_tuple([&](auto I) { return array[I]; }, Number{}); + } + + static auto MakeDescriptor_N_H_W(const Lengths& lengths, const Strides& stride) + { + // create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], + // d[NumDim-1]] + const auto desc = + make_naive_tensor_descriptor(ConvertArrayToTuple(lengths), ConvertArrayToTuple(stride)); + + // merge nd to 3d descriptor, shape: [(d[0] * d[1] * d[2] * ... 
* d[NumDim-3]), d[NumDim-2], + // d[NumDim-1]] + // => [N, H, W] + const index_t H = *std::next(rbegin(lengths)); + const index_t W = *rbegin(lengths); + const auto desc_n_h_w = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(ConvertArrayToTuple(lengths)), + make_pass_through_transform(H), + make_pass_through_transform(W)), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{}), + Sequence{}, + Sequence{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return PadTensorDescriptor( + desc_n_h_w, make_tuple(NPerBlock, HPerBlock, WPerBlock), Sequence{}); + } + + using InGridDesc = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1})); + using OutGridDesc = InGridDesc; + + using GridwisePermute = GridwisePermute< + InGridDesc, + OutGridDesc, + InDataType, + OutDataType, + ElementwiseOperation, + BlockSize, + NPerBlock, + HPerBlock, + WPerBlock, + InBlockLdsExtraW, + InBlockTransferThreadClusterLengths, + InBlockTransferThreadClusterArrangeOrder, + SrcVectorDim - (NumDim - 3), // calculate new SrcVectorDim for the merged descriptor + DstVectorDim - (NumDim - 3), // calculate new DstVectorDim for the merged descriptor + SrcScalarPerVector, + DstScalarPerVector>; + + using Block2TileMap = typename GridwisePermute::DefaultBlock2TileMap; + + struct Argument : public BaseArgument + { + Argument(const Lengths& in_lengths, + const Strides& in_strides, + const Lengths& out_lengths, + const Strides& out_strides, + const void* in_dev_buffer, + void* out_dev_buffer, + ElementwiseOperation elementwise_op) + : in_dev_buffer_(static_cast(in_dev_buffer)), + out_dev_buffer_(static_cast(out_dev_buffer)), + in_grid_desc_(MakeDescriptor_N_H_W(in_lengths, in_strides)), + out_grid_desc_(MakeDescriptor_N_H_W(out_lengths, out_strides)), + in_lengths_(in_lengths), + in_strides_(in_strides), + out_lengths_(out_lengths), + out_strides_(out_strides), + elementwise_op_(elementwise_op), + block_2_tile_map_(GridwisePermute::MakeDefaultBlock2TileMap(in_grid_desc_)) + { + } + + const InDataType* in_dev_buffer_; + OutDataType* out_dev_buffer_; + InGridDesc in_grid_desc_; + OutGridDesc out_grid_desc_; + + Lengths in_lengths_; + Strides in_strides_; + Lengths out_lengths_; + Strides out_strides_; + + ElementwiseOperation elementwise_op_; + + Block2TileMap block_2_tile_map_; + }; + + struct Invoker : BaseInvoker + { + static float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const index_t grid_size = arg.block_2_tile_map_.CalculateGridSize(arg.in_grid_desc_); + + const auto kernel = kernel_nd_permute; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.in_grid_desc_, + arg.out_grid_desc_, + arg.in_dev_buffer_, + arg.out_dev_buffer_, + arg.elementwise_op_, + arg.block_2_tile_map_); + return elapsed_time; + } + + float Run(const BaseArgument* arg, + const StreamConfig& stream_config = StreamConfig{}) override final + { + const auto* const argument = dynamic_cast(arg); + if(!argument) + { + return NAN; + } + + return Run(*argument, stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + constexpr auto GetPaddedLength = [](index_t length, index_t tile_length) { + return math::integer_divide_ceil(length, tile_length) * tile_length; + }; + + constexpr auto IsScalarPerVectorValid = + [](index_t length, index_t stride, index_t scalar_per_vector) { + if(stride == 1 && length % scalar_per_vector == 0) + { + return true; + } + else if(stride != 1 && 
scalar_per_vector == 1) + { + return true; + } + + return false; + }; + + return IsScalarPerVectorValid(arg.in_lengths_[SrcVectorDim], + arg.in_strides_[SrcVectorDim], + SrcScalarPerVector) && + IsScalarPerVectorValid( + GetPaddedLength(arg.in_lengths_[SrcVectorDim], + (SrcVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)), + arg.in_strides_[SrcVectorDim], + SrcScalarPerVector) && + IsScalarPerVectorValid(arg.out_lengths_[DstVectorDim], + arg.out_strides_[DstVectorDim], + DstScalarPerVector) && + IsScalarPerVectorValid( + GetPaddedLength(arg.out_lengths_[DstVectorDim], + (DstVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)), + arg.in_strides_[DstVectorDim], + DstScalarPerVector) && + GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_); + }; + + // override methods inherited from 'BaseOperator' + bool IsSupportedArgument(const BaseArgument* arg) override final + { + const auto* const argument = dynamic_cast(arg); + if(!argument) + { + return false; + } + + return IsSupportedArgument(*argument); + } + + // override methods inherited from 'DevicePermute' + std::unique_ptr + MakeArgumentPointer(const Lengths& in_lengths, + const Strides& in_strides, + const Lengths& out_lengths, + const Strides& out_strides, + const void* in_dev_buffer, + void* out_dev_buffer, + ElementwiseOperation elementwise_op) override final + { + return std::make_unique(in_lengths, + in_strides, + out_lengths, + out_strides, + in_dev_buffer, + out_dev_buffer, + elementwise_op); + } + + std::unique_ptr MakeInvokerPointer() override final + { + return std::make_unique(); + }; + + // other constructor methods + template + static std::enable_if_t, Argument> + MakeArgument(Args&&... args) noexcept(std::is_nothrow_constructible_v) + { + return Argument{std::forward(args)...}; + } + + static std::enable_if_t, Invoker> + MakeInvoker() noexcept(std::is_nothrow_default_constructible_v) + { + return Invoker{}; + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp new file mode 100755 index 000000000000..fcb9b8a90346 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp @@ -0,0 +1,388 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
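+ // DevicePool2dFwd_NHWC_NHWC maps 2D window pooling onto a generic [M, K] reduction:
+ // each output element (n, ho, wo, c) is an invariant index, giving M = N * Ho * Wo * C,
+ // while the pooling window positions give the reduce length K = Y * X. Window strides,
+ // dilations and padding are folded into the input tensor descriptor via pad/embed
+ // transforms, and the reduction itself is carried out by a threadwise gridwise
+ // reduction kernel.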
+ +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DevicePool2dFwd_NHWC_NHWC : public DevicePoolFwd<4, + 2, + InDataType, + OutDataType, + IndexDataType, + tensor_layout::convolution::NHWC, + tensor_layout::convolution::NHWC, + ReduceOpId, + OutputIndex> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t InOutRank = 4; + static constexpr index_t WindowRank = 2; + + using ReduceOperation = typename reduce_binary_operator::opType; + + using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; + + using AccElementwiseOperation = + typename reduce_unary_operator::AccElementwiseOperation; + + static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeABGridDescriptor_A_M_K_B_M(std::vector input_nchw_lengths, + std::vector output_nchw_lengths, + std::vector input_nchw_stride, + std::vector output_nchw_stride, + std::vector window_spatial_yx_lengths, + std::vector window_yx_strides, + std::vector window_yx_dilations, + std::vector input_left_hw_pads, + std::vector input_right_hw_pads) + { + const index_t N = input_nchw_lengths[0]; + const index_t C = input_nchw_lengths[1]; + const index_t Hi = input_nchw_lengths[2]; + const index_t Wi = input_nchw_lengths[3]; + + const index_t Ho = output_nchw_lengths[2]; + const index_t Wo = output_nchw_lengths[3]; + const index_t Y = window_spatial_yx_lengths[0]; + const index_t X = window_spatial_yx_lengths[1]; + + const index_t WindowStrideH = window_yx_strides[0]; + const index_t WindowStrideW = window_yx_strides[1]; + + const index_t WindowDilationH = window_yx_dilations[0]; + const index_t WindowDilationW = window_yx_dilations[1]; + + const index_t InLeftPadH = input_left_hw_pads[0]; + const index_t InLeftPadW = input_left_hw_pads[1]; + + const index_t InRightPadH = input_right_hw_pads[0]; + const index_t InRightPadW = input_right_hw_pads[1]; + + const index_t MRaw = N * Ho * Wo * C; + const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw; + + const index_t KRaw = Y * X; + const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw; + + // A[ReduceM, ReduceK] + const index_t Ni_stride = input_nchw_stride[0]; + const index_t Ci_stride = input_nchw_stride[1]; + const index_t Hi_stride = input_nchw_stride[2]; + const index_t Wi_stride = input_nchw_stride[3]; + + const auto in_grid_desc_n_hi_wi_c = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(Ni_stride, Hi_stride, Wi_stride, Ci_stride)); + + const auto in_grid_desc_n_hip_wip_c = transform_tensor_descriptor( + in_grid_desc_n_hi_wi_c, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + 
make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_grid_desc_n_y_ho_x_wo_c = transform_tensor_descriptor( + in_grid_desc_n_hip_wip_c, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_grid_desc_reducemraw_reducekraw = + transform_tensor_descriptor(in_grid_desc_n_y_ho_x_wo_c, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)), + make_merge_transform(make_tuple(Y, X))), + make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor( + in_grid_desc_reducemraw_reducekraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // B[ReduceM] + const index_t No_stride = output_nchw_stride[0]; + const index_t Co_stride = output_nchw_stride[1]; + const index_t Ho_stride = output_nchw_stride[2]; + const index_t Wo_stride = output_nchw_stride[3]; + + const auto out_grid_desc_n_ho_wo_c = make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, C), make_tuple(No_stride, Ho_stride, Wo_stride, Co_stride)); + + const auto out_grid_desc_reducemraw = + transform_tensor_descriptor(out_grid_desc_n_ho_wo_c, + make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto out_grid_desc_reducem = + transform_tensor_descriptor(out_grid_desc_reducemraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem); + } + + using ABGridDescs = + decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {})); + + using AGridDesc_M_K = remove_cvref_t; + using BGridDesc_M = remove_cvref_t; + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_dev, + OutDataType* p_out_dev, + IndexDataType* p_out_indices_dev, + std::vector& input_nchw_lengths, + std::vector& output_nchw_lengths, + std::vector& input_nchw_stride, + std::vector& output_nchw_stride, + std::vector&, // indices_nchw_stride + std::vector& window_spatial_yx_lengths, + std::vector& window_yx_strides, + std::vector& window_yx_dilations, + std::vector& input_left_hw_pads, + std::vector& input_right_hw_pads) + : p_in_dev_{p_in_dev}, + p_out_dev_{p_out_dev}, + p_out_indices_dev_{p_out_indices_dev}, + a_grid_desc_m_k_{}, + b_grid_desc_m_{}, + input_nchw_lengths_{input_nchw_lengths}, + output_nchw_lengths_{output_nchw_lengths}, + input_nchw_stride_{input_nchw_stride}, + output_nchw_stride_{output_nchw_stride} + { + const auto descs = MakeABGridDescriptor_A_M_K_B_M(input_nchw_lengths, + output_nchw_lengths, + input_nchw_stride, + output_nchw_stride, + window_spatial_yx_lengths, + window_yx_strides, + window_yx_dilations, + input_left_hw_pads, + input_right_hw_pads); + + a_grid_desc_m_k_ = descs[I0]; + b_grid_desc_m_ = descs[I1]; + + int32_t reduceLength = 
window_spatial_yx_lengths[0] * window_spatial_yx_lengths[1]; + + std::tie(in_element_op_, acc_element_op_) = + reduce_unary_operator::GetElementwiseOperator(reduceLength); + } + + const InDataType* p_in_dev_; + OutDataType* p_out_dev_; + IndexDataType* p_out_indices_dev_; + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_M b_grid_desc_m_; + + InElementwiseOperation in_element_op_; + AccElementwiseOperation acc_element_op_; + + // for checking vector load/store + std::vector input_nchw_lengths_; + std::vector output_nchw_lengths_; + std::vector input_nchw_stride_; + std::vector output_nchw_stride_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + // for NHWC, the dim C is the fastest dimension, and is not reduced. + // Hence, it is in M dimension for reduction kernel. + static constexpr index_t InSrcOutDstVectorDim = 0; // 0: M, 1: K + + using gridwise_reduce = + GridwiseReduction_mk_to_m_threadwise; + + const auto kernel = + kernel_reduce_threadwise; + + ck::index_t M = arg.a_grid_desc_m_k_.GetLength(I0); + + const index_t grid_size = (M / M_BlockTileSize); + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.a_grid_desc_m_k_, + arg.b_grid_desc_m_, + arg.in_element_op_, + arg.acc_element_op_, + float(1), + arg.p_in_dev_, + nullptr, + float(0), + arg.p_out_dev_, + arg.p_out_indices_dev_); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + // C should be fastest dimension + if(pArg->input_nchw_stride_[1] != 1) + return false; + + for(int i = 0; i < InOutRank; ++i) + { + if(pArg->input_nchw_stride_[i] == 1 && + pArg->input_nchw_lengths_[i] % InSrcOutDstVectorSize != 0) + return false; + + if(pArg->output_nchw_stride_[i] == 1 && + pArg->output_nchw_lengths_[i] % InSrcOutDstVectorSize != 0) + return false; + } + + return true; + } + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in_dev, + void* p_out_dev, + void* p_out_indices_dev, + std::vector input_nchw_lengths, + std::vector window_yx_lengths, + std::vector output_nchw_lengths, + std::vector input_nchw_stride, + std::vector output_nchw_stride, + std::vector indices_nchw_stride, + std::vector window_yx_strides, + std::vector window_yx_dilations, + std::vector input_left_hw_pads, + std::vector input_right_hw_pads, + std::vector pooling_dims) override + { + if(input_nchw_lengths.size() != InOutRank || window_yx_lengths.size() != WindowRank || + input_nchw_lengths.size() != InOutRank || window_yx_strides.size() != WindowRank || + window_yx_dilations.size() != WindowRank || input_left_hw_pads.size() != WindowRank || + input_right_hw_pads.size() != WindowRank) + throw std::runtime_error("dimension is incorrect"); + + if(pooling_dims != std::vector{2, 3}) + throw std::runtime_error("pooling_dims only support {2, 3} in pool2d so far"); + + if(output_nchw_stride != indices_nchw_stride) + throw std::runtime_error( + "output_nchw_stride need to be equal to indices_nchw_stride for now"); + + return std::make_unique(static_cast(p_in_dev), + static_cast(p_out_dev), + static_cast(p_out_indices_dev), + input_nchw_lengths, + output_nchw_lengths, + input_nchw_stride, + output_nchw_stride, + indices_nchw_stride, + window_yx_lengths, + window_yx_strides, + 
window_yx_dilations, + input_left_hw_pads, + input_right_hw_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DevicePool2dFwd_NHWC_NHWC<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp new file mode 100755 index 000000000000..384805a0a3de --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp @@ -0,0 +1,411 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp" +#include "ck/tensor_operation/gpu/device/device_pool_fwd.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DevicePool3dFwd_NDHWC_NDHWC : public DevicePoolFwd<5, + 3, + InDataType, + OutDataType, + IndexDataType, + tensor_layout::convolution::NDHWC, + tensor_layout::convolution::NDHWC, + ReduceOpId, + OutputIndex> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr index_t InOutRank = 5; + static constexpr index_t WindowRank = 3; + + using ReduceOperation = typename reduce_binary_operator::opType; + + using InElementwiseOperation = + typename reduce_unary_operator::InElementwiseOperation; + + using AccElementwiseOperation = + typename reduce_unary_operator::AccElementwiseOperation; + + static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeABGridDescriptor_A_M_K_B_M(std::vector input_ncdhw_lengths, + std::vector output_ncdhw_lengths, + std::vector input_ncdhw_stride, + std::vector output_ncdhw_stride, + std::vector window_spatial_zyx_lengths, + std::vector window_zyx_strides, + std::vector window_zyx_dilations, + std::vector input_left_dhw_pads, + std::vector input_right_dhw_pads) + { + const index_t N = input_ncdhw_lengths[0]; + const index_t C = input_ncdhw_lengths[1]; + const index_t Di = input_ncdhw_lengths[2]; + const index_t Hi = input_ncdhw_lengths[3]; + const index_t Wi = input_ncdhw_lengths[4]; + + const index_t Do = output_ncdhw_lengths[2]; + const index_t Ho = 
output_ncdhw_lengths[3]; + const index_t Wo = output_ncdhw_lengths[4]; + + const index_t Z = window_spatial_zyx_lengths[0]; + const index_t Y = window_spatial_zyx_lengths[1]; + const index_t X = window_spatial_zyx_lengths[2]; + + const index_t WindowStrideD = window_zyx_strides[0]; + const index_t WindowStrideH = window_zyx_strides[1]; + const index_t WindowStrideW = window_zyx_strides[2]; + + const index_t WindowDilationD = window_zyx_dilations[0]; + const index_t WindowDilationH = window_zyx_dilations[1]; + const index_t WindowDilationW = window_zyx_dilations[2]; + + const index_t InLeftPadD = input_left_dhw_pads[0]; + const index_t InLeftPadH = input_left_dhw_pads[1]; + const index_t InLeftPadW = input_left_dhw_pads[2]; + + const index_t InRightPadD = input_right_dhw_pads[0]; + const index_t InRightPadH = input_right_dhw_pads[1]; + const index_t InRightPadW = input_right_dhw_pads[2]; + + const index_t MRaw = N * Do * Ho * Wo * C; + const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw; + + const index_t KRaw = Z * Y * X; + const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw; + + // A[ReduceM, ReduceK] + const index_t Ni_stride = input_ncdhw_stride[0]; + const index_t Ci_stride = input_ncdhw_stride[1]; + const index_t Di_stride = input_ncdhw_stride[2]; + const index_t Hi_stride = input_ncdhw_stride[3]; + const index_t Wi_stride = input_ncdhw_stride[4]; + + const auto in_grid_desc_n_di_hi_wi_c = make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(Ni_stride, Di_stride, Hi_stride, Wi_stride, Ci_stride)); + + const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor( + in_grid_desc_n_di_hi_wi_c, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor( + in_grid_desc_n_dip_hip_wip_c, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(WindowDilationD, WindowStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor( + in_grid_desc_n_z_do_y_ho_x_wo_c, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C)), + make_merge_transform(make_tuple(Z, Y, X))), + make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor( + in_grid_desc_reducemraw_reducekraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // B[ReduceM] + const index_t No_stride = output_ncdhw_stride[0]; + const index_t Co_stride = output_ncdhw_stride[1]; + const index_t Do_stride = 
output_ncdhw_stride[2]; + const index_t Ho_stride = output_ncdhw_stride[3]; + const index_t Wo_stride = output_ncdhw_stride[4]; + + const auto out_grid_desc_n_do_ho_wo_c = make_naive_tensor_descriptor( + make_tuple(N, Do, Ho, Wo, C), + make_tuple(No_stride, Do_stride, Ho_stride, Wo_stride, Co_stride)); + + const auto out_grid_desc_reducemraw = transform_tensor_descriptor( + out_grid_desc_n_do_ho_wo_c, + make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto out_grid_desc_reducem = + transform_tensor_descriptor(out_grid_desc_reducemraw, + make_tuple(make_right_pad_transform(MRaw, MPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + + return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem); + } + + using ABGridDescs = + decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {})); + + using AGridDesc_M_K = remove_cvref_t; + using BGridDesc_M = remove_cvref_t; + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_in_dev, + OutDataType* p_out_dev, + IndexDataType* p_out_indices_dev, + std::vector& input_ncdhw_lengths, + std::vector& output_ncdhw_lengths, + std::vector& input_ncdhw_stride, + std::vector& output_ncdhw_stride, + std::vector&, // indices_ncdhw_stride + std::vector& window_spatial_zyx_lengths, + std::vector& window_zyx_strides, + std::vector& window_zyx_dilations, + std::vector& input_left_dhw_pads, + std::vector& input_right_dhw_pads) + : p_in_dev_{p_in_dev}, + p_out_dev_{p_out_dev}, + p_out_indices_dev_{p_out_indices_dev}, + a_grid_desc_m_k_{}, + b_grid_desc_m_{}, + input_ncdhw_lengths_{input_ncdhw_lengths}, + output_ncdhw_lengths_{output_ncdhw_lengths}, + input_ncdhw_stride_{input_ncdhw_stride}, + output_ncdhw_stride_{output_ncdhw_stride} + { + const auto descs = MakeABGridDescriptor_A_M_K_B_M(input_ncdhw_lengths, + output_ncdhw_lengths, + input_ncdhw_stride, + output_ncdhw_stride, + window_spatial_zyx_lengths, + window_zyx_strides, + window_zyx_dilations, + input_left_dhw_pads, + input_right_dhw_pads); + + a_grid_desc_m_k_ = descs[I0]; + b_grid_desc_m_ = descs[I1]; + + int32_t reduceLength = window_spatial_zyx_lengths[0] * window_spatial_zyx_lengths[1] * + window_spatial_zyx_lengths[2]; + + std::tie(in_element_op_, acc_element_op_) = + reduce_unary_operator::GetElementwiseOperator(reduceLength); + } + + const InDataType* p_in_dev_; + OutDataType* p_out_dev_; + IndexDataType* p_out_indices_dev_; + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_M b_grid_desc_m_; + + InElementwiseOperation in_element_op_; + AccElementwiseOperation acc_element_op_; + + // for checking vector load/store + std::vector input_ncdhw_lengths_; + std::vector output_ncdhw_lengths_; + std::vector input_ncdhw_stride_; + std::vector output_ncdhw_stride_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + // for NDHWC, the dim C is the fastest dimension, and is not reduced. + // Hence, it is in M dimension for reduction kernel. 
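+            // Concretely, each of the M = N * Do * Ho * Wo * C output points is one reduction row,
+            // and its K = Z * Y * X window taps are accumulated thread-wise by the
+            // GridwiseReduction_mk_to_m_threadwise kernel instantiated below.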
+ static constexpr index_t InSrcOutDstVectorDim = 0; // 0: M, 1: K + + using gridwise_reduce = + GridwiseReduction_mk_to_m_threadwise; + + const auto kernel = + kernel_reduce_threadwise; + + ck::index_t M = arg.a_grid_desc_m_k_.GetLength(I0); + + const index_t grid_size = (M / M_BlockTileSize); + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.a_grid_desc_m_k_, + arg.b_grid_desc_m_, + arg.in_element_op_, + arg.acc_element_op_, + float(1), + arg.p_in_dev_, + nullptr, + float(0), + arg.p_out_dev_, + arg.p_out_indices_dev_); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + // C should be fastest dimension + if(pArg->input_ncdhw_stride_[1] != 1) + return false; + + for(int i = 0; i < InOutRank; ++i) + { + if(pArg->input_ncdhw_stride_[i] == 1 && + pArg->input_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0) + return false; + + if(pArg->output_ncdhw_stride_[i] == 1 && + pArg->output_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0) + return false; + } + + return true; + } + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_in_dev, + void* p_out_dev, + void* p_out_indices_dev, + std::vector input_ncdhw_lengths, + std::vector window_zyx_lengths, + std::vector output_ncdhw_lengths, + std::vector input_ncdhw_stride, + std::vector output_ncdhw_stride, + std::vector indices_ncdhw_stride, + std::vector window_zyx_strides, + std::vector window_zyx_dilations, + std::vector input_left_dhw_pads, + std::vector input_right_dhw_pads, + std::vector pooling_dims) override + { + if(input_ncdhw_lengths.size() != InOutRank || window_zyx_lengths.size() != WindowRank || + input_ncdhw_lengths.size() != InOutRank || window_zyx_strides.size() != WindowRank || + window_zyx_dilations.size() != WindowRank || input_left_dhw_pads.size() != WindowRank || + input_right_dhw_pads.size() != WindowRank) + throw std::runtime_error("dimension is incorrect"); + + if(pooling_dims != std::vector{2, 3, 4}) + throw std::runtime_error("pooling_dims only support {2, 3, 4} in pool3d so far"); + + if(output_ncdhw_stride != indices_ncdhw_stride) + throw std::runtime_error( + "output_ncdhw_stride need to be equal to indices_ncdhw_stride for now"); + + return std::make_unique(static_cast(p_in_dev), + static_cast(p_out_dev), + static_cast(p_out_indices_dev), + input_ncdhw_lengths, + output_ncdhw_lengths, + input_ncdhw_stride, + output_ncdhw_stride, + indices_ncdhw_stride, + window_zyx_lengths, + window_zyx_strides, + window_zyx_dilations, + input_left_dhw_pads, + input_right_dhw_pads); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DevicePool3dFwd_NDHWC_NDHWC<" << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp new file mode 100755 index 000000000000..7334da0e35af --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_put_element.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/stream_utility.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// output[indices] = input +template +struct DevicePutElementImpl + : public DevicePutElement +{ + template + static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize) + { + constexpr auto I0 = Number<0>{}; + + const auto m = desc_m.GetLength(I0); + const index_t loop_step = gridSize * blockSize * InVectorSize; + const auto pad = math::integer_least_multiple(m, loop_step) - m; + const auto desc_m_pad = + transform_tensor_descriptor(desc_m, + make_tuple(make_right_pad_transform(m, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return desc_m_pad; + } + + static auto MakeDescriptor_M(index_t length, index_t gridSize, index_t blockSize) + { + const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length)); + return PadDescriptor_M_1d(desc_m, gridSize, blockSize); + } + + using InGrid1dDesc = decltype(MakeDescriptor_M(1, 1, 1)); + + using GridwisePutElement = GridwisePutElement_1D; + + struct Argument : public BaseArgument + { + Argument(const InDataType* p_input, + const IndexDataType* p_indices, + OutDataType* p_output, + index_t input_length, + ElementwiseOperation elementwise_op) + : p_input_{p_input}, + p_indices_{p_indices}, + p_output_{p_output}, + input_length_raw_{input_length}, + elementwise_op_{elementwise_op}, + blockSize_{256} + { + } + + const InDataType* p_input_; + const IndexDataType* p_indices_; + OutDataType* p_output_; + index_t input_length_raw_; + ElementwiseOperation elementwise_op_; + index_t blockSize_; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + index_t gridSize = getAvailableComputeUnitCount(stream_config); + InGrid1dDesc in_grid_desc = + MakeDescriptor_M(arg.input_length_raw_, gridSize, arg.blockSize_); + + const auto kernel = kernel_put_element_1d; + + float elapsed_time = launch_and_time_kernel(stream_config, + kernel, + dim3(gridSize), + dim3(arg.blockSize_), + 0, + in_grid_desc, + arg.p_input_, + arg.p_indices_, + arg.p_output_, + arg.elementwise_op_); + return elapsed_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if(pArg->input_length_raw_ % InVectorSize != 0) + { + return false; + } + return true; + } + + std::unique_ptr MakeArgumentPointer(const void* p_input, + const void* p_indices, + void* p_output, + index_t input_length, + index_t, + ElementwiseOperation elementwise_op) 
override + { + return std::make_unique(static_cast(p_input), + static_cast(p_indices), + static_cast(p_output), + input_length, + elementwise_op); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp new file mode 100755 index 000000000000..67956d9f3f0e --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/utility/reduction_operator.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +// here, inLengths[] is already shuffled so that lengths of invariant dims are included before those +// of reduce dims +template +std::pair get_2d_lengths(const std::vector& inLengths) +{ + static_assert(Rank <= 12, "bigger Rank size not supported!"); + + long_index_t invariant_total_length = 1; + long_index_t reduce_total_length = 1; + + constexpr int NumInvariantDim = Rank - NumReduceDim; + + for(int i = NumInvariantDim; i < Rank; i++) + reduce_total_length *= inLengths[i]; + + for(int i = 0; i < NumInvariantDim; i++) + invariant_total_length *= inLengths[i]; + + return std::make_pair(invariant_total_length, reduce_total_length); +}; + +template +std::pair get_2d_lengths(const std::array& inLengths) +{ + static_assert(Rank <= 12, "bigger Rank size not supported!"); + + long_index_t invariant_total_length = 1; + long_index_t reduce_total_length = 1; + + constexpr int NumInvariantDim = Rank - NumReduceDim; + + for(int i = NumInvariantDim; i < Rank; i++) + reduce_total_length *= inLengths[i]; + + for(int i = 0; i < NumInvariantDim; i++) + invariant_total_length *= inLengths[i]; + + return std::make_pair(invariant_total_length, reduce_total_length); +}; + +// helper functions using variadic template arguments +template +auto make_tuple_from_array_and_index_seq(const std::vector& lengths, Sequence) +{ + return make_tuple(static_cast(lengths[Ns])...); +}; + +template +auto make_tuple_from_array(const std::vector& lengths, Number) +{ + static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions"); + + constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{}; + + return make_tuple_from_array_and_index_seq(lengths, index_seq); +}; + +template +std::vector shuffle_tensor_dimensions(const std::vector& origLengthsStrides, + const std::vector& reduceDims) +{ + std::vector newLengthsStrides; + + assert(Rank == origLengthsStrides.size() && NumReduceDim == reduceDims.size()); + + int reduceFlag = 0; + + // flag the bits for the reduceDims + for(int i = 0; i < NumReduceDim; i++) + { + reduceFlag |= 1 << reduceDims[i]; + }; + + // collect invariant dimensions + for(int i = 0; i < Rank; i++) + if((reduceFlag & (1 << i)) == 0) + { + newLengthsStrides.push_back(origLengthsStrides[i]); + }; + + // collect reduce dimensions + for(int i = 0; i < Rank; i++) + if((reduceFlag & (1 << i)) > 0) + { + newLengthsStrides.push_back(origLengthsStrides[i]); + }; + + return newLengthsStrides; +}; + +template 
+std::array +shuffle_tensor_dimensions(const std::array& origLengthsStrides, + const std::array& reduceDims) +{ + std::array newLengthsStrides; + + int reduceFlag = 0; + + // flag the bits for the reduceDims + for(int i = 0; i < NumReduceDim; i++) + { + reduceFlag |= 1 << reduceDims[i]; + }; + + // collect invariant dimensions + int pos = 0; + for(int i = 0; i < Rank; i++) + if((reduceFlag & (1 << i)) == 0) + { + newLengthsStrides[pos++] = origLengthsStrides[i]; + }; + + // collect reduce dimensions + for(int i = 0; i < Rank; i++) + if((reduceFlag & (1 << i)) > 0) + { + newLengthsStrides[pos++] = origLengthsStrides[i]; + }; + + return newLengthsStrides; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp new file mode 100755 index 000000000000..b4873e3403ce --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp @@ -0,0 +1,551 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceReduceMultiBlock : public DeviceReduce +{ + static_assert(Rank <= 12, "Bigger Rank size is not supported!"); + static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize, + "Invalid thread cluster size assignments!"); + + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + using IndexDataType = int32_t; + + static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex; + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + static constexpr index_t NumSrcDim = Rank; + static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 
1 : NumInvariantDim; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + // So far, only AtomicAdd is considered, other Atomic Operation like AtomicMax can be added + // later + static constexpr bool use_multiblock = + (OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd); + + static_assert(ck::reduce::InMemoryDataOperationSupportedOnDataType::value, + "The OutDataType must support the specified OutMemoryDataOperation!"); + + static_assert(!use_multiblock || (use_multiblock && !OutputIndex), + "MultiBlock reduction can only be used when outputing index is not required"); + + static_assert( + ReduceOperation::IsCompatibleInMemoryDataOperation(OutMemoryDataOperation), + "The reduction accumulation operation must be compatible with the OutMemoryDataOperation!"); + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeSrc2dDescriptor(const std::array& inLengths, + const std::array& inStrides, + int blkGroupSize, + int numBlockTileIteration) + { + const auto tupleSrcLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + const auto tupleSrcStrides = + generate_tuple([&](auto I) { return inStrides[I]; }, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = generate_tuple( + [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + static auto MakeDst1dDescriptor(const std::array& outLengths, + const std::array& outStrides) + { + const auto tupleDstLengths = + generate_tuple([&](auto I) { return outLengths[I]; }, Number{}); + const auto tupleDstStrides = + generate_tuple([&](auto I) { return 
outStrides[I]; }, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); + + const auto outPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto out_grid_desc_m_padded = transform_tensor_descriptor( + out_grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, outPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + static auto MakeDst1dDescriptorForBufferSet(const std::array& outLengths, + const std::array& outStrides) + { + const auto tupleDstLengths = + generate_tuple([&](auto I) { return outLengths[I]; }, Number{}); + const auto tupleDstStrides = + generate_tuple([&](auto I) { return outStrides[I]; }, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto length = out_grid_desc_m.GetLength(Number<0>{}); + + const auto pad = math::integer_least_multiple(length, BlockSize) - length; + + auto out_grid_desc_m_padded = + transform_tensor_descriptor(out_grid_desc_m, + make_tuple(make_right_pad_transform(length, pad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + struct Argument : public BaseArgument + { + Argument(const std::array inLengths, + const std::array inStrides, + const std::array outLengths, + const std::array outStrides, + const std::array reduceDims, + double alpha, + double beta, + const InDataType* in_dev, + const IndexDataType* in_index_dev, + OutDataType* out_dev, + IndexDataType* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) + : outLengths_{outLengths}, + outStrides_{outStrides}, + in_dev_{in_dev}, + in_index_dev_{in_index_dev}, + out_dev_{out_dev}, + out_index_dev_{out_index_dev}, + in_elementwise_op_{in_elementwise_op}, + acc_elementwise_op_{acc_elementwise_op} + { + if(Rank != inLengths.size() || Rank != inStrides.size() || + NumReduceDim != reduceDims.size()) + { + throw std::runtime_error( + "One of inLengths/inStrides/reduceDims has invalid size!" + "\nExpected size inLengths: " + + std::to_string(Rank) + ", inStrides: " + std::to_string(Rank) + + ", reduceDims: " + std::to_string(NumReduceDim) + + "\nBut have inLengths: " + std::to_string(inLengths.size()) + + ", inStrides: " + std::to_string(inStrides.size()) + + ", reduceDims: " + std::to_string(reduceDims.size())); + } + + for(std::size_t i = 0; i < reduceDims.size(); ++i) + { + if(reduceDims[i] < 0 || reduceDims[i] >= Rank) + { + throw std::runtime_error("Provided reduce dimension exceed input tensor Rank!" 
+ "\nHave reduceDims[" + + std::to_string(i) + + "]: " + std::to_string(reduceDims[i])); + } + } + + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); + + alpha_ = type_convert(alpha); + beta_ = type_convert(beta); + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(inLengths_); + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length = 1; + else + invariant_lowest_length = inLengths_[NumInvariantDim - 1]; + + reduce_lowest_length = inLengths_[Rank - 1]; + + if constexpr(use_multiblock) + { + + int iterations = 1; + while(true) + { + int testBlkGroupSize = + (reduce_total_length + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + // we want the blkGroupSize be not more than 128 + if(testBlkGroupSize <= 128) + break; + + iterations++; + }; + + blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) / + (K_BlockTileSize * iterations); + + numBlockTileIteration = iterations; + } + else + { + blkGroupSize = 1; + numBlockTileIteration = + (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + }; + + gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize * blkGroupSize; + + gridSize_pre = + math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize; + } + + std::array inLengths_; + std::array inStrides_; + std::array outLengths_; + std::array outStrides_; + + AccDataType alpha_; + AccDataType beta_; + + const InDataType* in_dev_; + const IndexDataType* in_index_dev_; + OutDataType* out_dev_; + IndexDataType* out_index_dev_; + + InElementwiseOperation in_elementwise_op_; + AccElementwiseOperation acc_elementwise_op_; + + index_t invariant_lowest_length; + index_t reduce_lowest_length; + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + int blkGroupSize; + int numBlockTileIteration; + size_t gridSize; + + size_t gridSize_pre; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto in_grid_desc_m_k = DeviceReduceMultiBlock::MakeSrc2dDescriptor( + arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); + const auto out_grid_desc_m = + DeviceReduceMultiBlock::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_); + const auto out_grid_desc_m_2 = DeviceReduceMultiBlock::MakeDst1dDescriptorForBufferSet( + arg.outLengths_, arg.outStrides_); + + using InGridDesc_M_K = decltype(in_grid_desc_m_k); + using OutGridDesc_M = decltype(out_grid_desc_m); + using OutGridDesc_M_2 = decltype(out_grid_desc_m_2); + + using GridwiseReduce = GridwiseReduction_mk_to_m_multiblock; + + const auto kernel_main = kernel_reduce_multiblock; + + float avg_time = 0; + + if constexpr(use_multiblock) + { + const auto identityVal = + ck::reduce::GetIdentityValueForInMemoryDataOperation( + OutMemoryDataOperation); + + const auto kernel_pre = + kernel_buffer_set_value; + + avg_time += launch_and_time_kernel(stream_config, + kernel_pre, + dim3(arg.gridSize_pre), + dim3(BlockSize), + 0, + out_grid_desc_m_2, + arg.out_dev_, + identityVal); + }; + + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + in_grid_desc_m_k, + out_grid_desc_m, + arg.in_elementwise_op_, + arg.acc_elementwise_op_, + arg.blkGroupSize, + arg.numBlockTileIteration, + arg.alpha_, + arg.in_dev_, + arg.in_index_dev_, + arg.beta_, + arg.out_dev_, 
+ arg.out_index_dev_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + static bool IsSupportedArgument(const Argument* pArg) + { + if constexpr(use_multiblock) + { + if(static_cast(pArg->beta_) != 0.0f) + return (false); + }; + + if constexpr(InSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return (false); + } + else + { + if(pArg->inStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(pArg->invariant_lowest_length % InSrcVectorSize != 0) + return (false); + }; + } + else + { + if(pArg->inStrides_[Rank - 1] != 1) + return (false); + + if(pArg->reduce_lowest_length % InSrcVectorSize != 0) + return (false); + }; + + // To improve + if(pArg->invariant_lowest_length % OutDstVectorSize != 0) + return (false); + + if constexpr(use_multiblock) + { + // blkGroupSize of 1 should be handled by Blockwise path using + // InMemoryDataOperationEnum::Set + if(pArg->blkGroupSize == 1) + return (false); + + // This is very strong restriction, but needed to avoid some failure + if(pArg->invariant_lowest_length % M_BlockTileSize != 0) + return (false); + } + else + { + // cases with very small reduce_total_length should be handled by ThreadWise kernel + // if(pArg->reduce_total_length / KThreadSliceSize < 2) + // return (false); + }; + + return (true); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(dynamic_cast(p_arg)); + }; + + std::unique_ptr + MakeArgumentPointer(const std::array inLengths, + const std::array inStrides, + const std::array outLengths, + const std::array outStrides, + const std::array reduceDims, + double alpha, + double beta, + const void* in_dev, + const void* in_index_dev, + void* out_dev, + void* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) override + { + return std::make_unique(inLengths, + inStrides, + outLengths, + outStrides, + reduceDims, + alpha, + beta, + static_cast(in_dev), + static_cast(in_index_dev), + static_cast(out_dev), + static_cast(out_index_dev), + in_elementwise_op, + acc_elementwise_op); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << (OutMemoryDataOperation == InMemoryDataOperationEnum::Set? "DeviceReduceBlockWise<" : "DeviceReduceMultiBlock<") << BlockSize << ","; + str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","; + str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp new file mode 100755 index 000000000000..8291575fb82c --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp @@ -0,0 +1,392 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
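+
+// DeviceReduceThreadWise: every thread privately reduces its MThreadSliceSize rows over the full
+// K extent (K_BlockTileSize = 1 * KThreadSliceSize), so no inter-thread accumulation is required.
+// IsSupportedArgument() below rejects reduce_total_length / KThreadSliceSize >= 32, leaving large
+// reductions to the blockwise/multiblock variants.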
+ +#pragma once + +#include +#include +#include + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceReduceThreadWise : public DeviceReduce + +{ + static_assert(Rank <= 12, "Bigger Rank size is not supported!"); + + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + using IndexDataType = int32_t; + + static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex; + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + static constexpr index_t NumSrcDim = Rank; + static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize; + + static auto MakeSrc2dDescriptor(const std::array& inLengths, + const std::array& inStrides) + { + const auto tupleSrcLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + const auto tupleSrcStrides = + generate_tuple([&](auto I) { return inStrides[I]; }, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = generate_tuple( + [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = + math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return 
(in_grid_desc_m_k_padded); + }; + + static auto MakeDst1dDescriptor(const std::array& outLengths, + const std::array& outStrides) + { + const auto tupleDstLengths = + generate_tuple([&](auto I) { return outLengths[I]; }, Number{}); + const auto tupleDstStrides = + generate_tuple([&](auto I) { return outStrides[I]; }, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); + + const auto outPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto out_grid_desc_m_padded = transform_tensor_descriptor( + out_grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, outPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + struct Argument : public BaseArgument + { + Argument(const std::array inLengths, + const std::array inStrides, + const std::array outLengths, + const std::array outStrides, + const std::array reduceDims, + double alpha, + double beta, + const InDataType* in_dev, + OutDataType* out_dev, + IndexDataType* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) + : outLengths_{outLengths}, + outStrides_{outStrides}, + in_dev_{in_dev}, + out_dev_{out_dev}, + out_index_dev_{out_index_dev}, + in_elementwise_op_{in_elementwise_op}, + acc_elementwise_op_{acc_elementwise_op} + { + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); + + alpha_ = type_convert(alpha); + beta_ = type_convert(beta); + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(inLengths_); + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length = 1; + else + invariant_lowest_length = inLengths_[NumInvariantDim - 1]; + + reduce_lowest_length = inLengths_[Rank - 1]; + + numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + + gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize; + } + + std::array inLengths_; + std::array inStrides_; + std::array outLengths_; + std::array outStrides_; + + AccDataType alpha_; + AccDataType beta_; + + const InDataType* in_dev_; + OutDataType* out_dev_; + IndexDataType* out_index_dev_; + + InElementwiseOperation in_elementwise_op_; + AccElementwiseOperation acc_elementwise_op_; + + index_t invariant_lowest_length; + index_t reduce_lowest_length; + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + int numBlockTileIteration; + size_t gridSize; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto in_grid_desc_m_k = + DeviceReduceThreadWise::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_); + const auto out_grid_desc_m = + DeviceReduceThreadWise::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_); + using InGridDesc_M_K = decltype(in_grid_desc_m_k); + using OutGridDesc_M = decltype(out_grid_desc_m); + + float avg_time = 0; + + using GridwiseReduce = + GridwiseReduction_mk_to_m_threadwise; + + const auto kernel = kernel_reduce_threadwise; + + avg_time = 
launch_and_time_kernel(stream_config, + kernel, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + in_grid_desc_m_k, + out_grid_desc_m, + arg.in_elementwise_op_, + arg.acc_elementwise_op_, + arg.alpha_, + arg.in_dev_, + nullptr, + arg.beta_, + arg.out_dev_, + arg.out_index_dev_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if constexpr(InSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return (false); + } + else + { + if(pArg->inStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(pArg->invariant_lowest_length % InSrcVectorSize != 0) + return (false); + }; + } + else + { + if(pArg->inStrides_[Rank - 1] != 1) + return (false); + + if(pArg->reduce_lowest_length % InSrcVectorSize != 0) + return (false); + }; + + // To improve + if(pArg->invariant_lowest_length % OutDstVectorSize != 0) + return (false); + + // cases with big reduce_total_length should be handled by Blockwise kernel + if(pArg->reduce_total_length / KThreadSliceSize >= 32) + return (false); + + return (true); + }; + + std::unique_ptr + MakeArgumentPointer(const std::array inLengths, + const std::array inStrides, + const std::array outLengths, + const std::array outStrides, + const std::array reduceDims, + double alpha, + double beta, + const void* in_dev, + const void* in_index_dev, + void* out_dev, + void* out_index_dev, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op) override + { + (void)in_index_dev; + + return std::make_unique(inLengths, + inStrides, + outLengths, + outStrides, + reduceDims, + alpha, + beta, + static_cast(in_dev), + static_cast(out_dev), + static_cast(out_index_dev), + in_elementwise_op, + acc_elementwise_op); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceReduceThreadWise<" << BlockSize << ","; + str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << 1 << "_S" << KThreadSliceSize << ","; + str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp new file mode 100755 index 000000000000..764b9312f358 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp @@ -0,0 +1,412 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
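+
+// DeviceReduceThreadWiseMultiD: same thread-wise reduction scheme as DeviceReduceThreadWise, but
+// each reduced row is additionally combined with NumDTensor auxiliary D tensors (laid out over the
+// invariant M dimension, see MakeDsDescriptor) through OutElementwiseOperation before the result
+// is stored.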
+ +#pragma once + +#include +#include +#include + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/device_reduce_multi_d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" + +#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceReduceThreadWiseMultiD : public DeviceReduceMultiD + +{ + static_assert(Rank <= 12, "Bigger Rank size is not supported!"); + + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + using IndexDataType = int32_t; + + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr index_t NumSrcDim = Rank; + static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize; + + static auto MakeSrc2dDescriptor(const std::array& inLengths, + const std::array& inStrides) + { + const auto tupleSrcLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + const auto tupleSrcStrides = + generate_tuple([&](auto I) { return inStrides[I]; }, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = generate_tuple( + [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = + math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + 
return (in_grid_desc_m_k_padded); + }; + + static auto MakeDst1dDescriptor(const std::array& outLengths, + const std::array& outStrides) + { + const auto tupleDstLengths = + generate_tuple([&](auto I) { return outLengths[I]; }, Number{}); + const auto tupleDstStrides = + generate_tuple([&](auto I) { return outStrides[I]; }, Number{}); + + auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + auto out_grid_desc_m = transform_tensor_descriptor( + outDesc, + make_tuple(make_merge_transform(tupleDstLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{}); + + const auto outPad = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + + auto out_grid_desc_m_padded = transform_tensor_descriptor( + out_grid_desc_m, + make_tuple(make_right_pad_transform(invariantLength, outPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + return (out_grid_desc_m_padded); + }; + + static auto + MakeDsDescriptor(const std::array, NumDTensor> DsLengths, + std::array, NumDTensor> DsStrides) + { + return generate_tuple( + [&](auto i) { + return DeviceReduceThreadWiseMultiD::MakeDst1dDescriptor(DsLengths[i], + DsStrides[i]); + }, + Number{}); + } + + using InGridDesc_M_K = decltype(MakeSrc2dDescriptor({}, {})); + using OutGridDesc_M = decltype(MakeDst1dDescriptor({}, {})); + using DsGridDesc_M = decltype(MakeDsDescriptor({}, {})); + + using GridwiseReduce = + GridwiseReduction_mk_to_m_threadwise_multi_d; + + using DsGridPointer = typename GridwiseReduce::DsGridPointer; + + struct Argument : public BaseArgument + { + Argument(const std::array inLengths, + const std::array inStrides, + const std::array, NumDTensor> DsLengths, + const std::array, NumDTensor> DsStrides, + const std::array outLengths, + const std::array outStrides, + const std::array reduceDims, + const InDataType* in_dev, + const std::array ds_dev, + OutDataType* out_dev, + const InElementwiseOperation in_elementwise_op, + const OutElementwiseOperation out_elementwise_op) + : DsLengths_{DsLengths}, + DsStrides_{DsStrides}, + outLengths_{outLengths}, + outStrides_{outStrides}, + in_dev_{in_dev}, + out_dev_{out_dev}, + in_elementwise_op_{in_elementwise_op}, + out_elementwise_op_{out_elementwise_op} + { + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(inLengths_); + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length = 1; + else + invariant_lowest_length = inLengths_[NumInvariantDim - 1]; + + reduce_lowest_length = inLengths_[Rank - 1]; + + numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + + gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + p_ds_grid_(i) = static_cast(ds_dev[i]); + }); + + ds_grid_desc_m_ = MakeDsDescriptor(DsLengths, DsStrides); + } + + std::array inLengths_; + std::array inStrides_; + + std::array, NumDTensor> DsLengths_; + std::array, NumDTensor> DsStrides_; + + std::array outLengths_; + std::array outStrides_; + + const InDataType* in_dev_; + OutDataType* out_dev_; + + DsGridPointer p_ds_grid_; + + InElementwiseOperation in_elementwise_op_; + OutElementwiseOperation out_elementwise_op_; + + 
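+        // one merged-M 1-D grid descriptor per auxiliary D tensor, built from DsLengths/DsStrides in the constructor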
DsGridDesc_M ds_grid_desc_m_; + + index_t invariant_lowest_length; + index_t reduce_lowest_length; + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + int numBlockTileIteration; + size_t gridSize; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto in_grid_desc_m_k = + DeviceReduceThreadWiseMultiD::MakeSrc2dDescriptor(arg.inLengths_, arg.inStrides_); + const auto out_grid_desc_m = + DeviceReduceThreadWiseMultiD::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_); + + float avg_time = 0; + + const auto kernel = kernel_reduce_threadwise_multi_d; + + avg_time = launch_and_time_kernel(stream_config, + kernel, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + in_grid_desc_m_k, + arg.ds_grid_desc_m_, + out_grid_desc_m, + arg.in_elementwise_op_, + arg.out_elementwise_op_, + arg.in_dev_, + arg.p_ds_grid_, + arg.out_dev_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + const Argument* pArg = dynamic_cast(p_arg); + + if constexpr(InSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return (false); + } + else + { + if(pArg->inStrides_[NumInvariantDim - 1] != 1) + return (false); + + if(pArg->invariant_lowest_length % InSrcVectorSize != 0) + return (false); + }; + } + else + { + if(pArg->inStrides_[Rank - 1] != 1) + return (false); + + if(pArg->reduce_lowest_length % InSrcVectorSize != 0) + return (false); + }; + + // To improve + if(pArg->invariant_lowest_length % OutDstVectorSize != 0) + return (false); + + std::cerr << "reduce_total_length = " << pArg->reduce_total_length + << " KThreadSliceSize = " << KThreadSliceSize << std::endl; + + // cases with big reduce_total_length should be handled by Blockwise kernel + if(pArg->reduce_total_length / KThreadSliceSize >= 32) + return (false); + + return (true); + }; + + std::unique_ptr + MakeArgumentPointer(const std::array inLengths, + const std::array inStrides, + const std::array, NumDTensor> DsLengths, + const std::array, NumDTensor> DsStrides, + const std::array outLengths, + const std::array outStrides, + const std::array reduceDims, + const void* in_dev, + const std::array ds_dev, + void* out_dev, + const InElementwiseOperation in_elementwise_op, + const OutElementwiseOperation out_elementwise_op) override + { + return std::make_unique(inLengths, + inStrides, + DsLengths, + DsStrides, + outLengths, + outStrides, + reduceDims, + static_cast(in_dev), + ds_dev, + static_cast(out_dev), + in_elementwise_op, + out_elementwise_op); + }; + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceReduceThreadWiseMultiD<" << BlockSize << ","; + str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ","; + str << "K_C" << 1 << "_S" << KThreadSliceSize << ","; + str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp new file mode 100755 index 000000000000..8eff9d241509 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp @@ -0,0 +1,417 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/tensor_operation/gpu/device/device_softmax.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_softmax.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceSoftmaxImpl : public DeviceSoftmax +{ + static constexpr index_t NumInvariantDim = Rank - NumReduceDim; + + static constexpr index_t NumSrcDim = Rank; + static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim; + static constexpr bool reduceAllDim = (NumInvariantDim == 0); + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + static auto MakeSrc2dDescriptor(const std::vector& inLengths, + const std::vector& inStrides, + int blkGroupSize, + int numBlockTileIteration) + { + const auto tupleSrcLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + const auto tupleSrcStrides = + generate_tuple([&](auto I) { return inStrides[I]; }, Number{}); + + const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); + + const auto in_grid_desc_m_k = [&]() { + if constexpr(reduceAllDim) + { + const auto one_dim_inDesc = transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(tupleSrcLengths)), + make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}), + make_tuple(Sequence<0>{})); + + return transform_tensor_descriptor(one_dim_inDesc, + make_tuple(make_unmerge_transform(make_tuple( + 1, one_dim_inDesc.GetLength(Number<0>{})))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + } + else + { + using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type; + using ReduceDims = typename arithmetic_sequence_gen::type; + + const auto reduceDimLengths = generate_tuple( + [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number{}); + const auto invariantDimLengths = + generate_tuple([&](auto I) { return inLengths[I]; }, Number{}); + + return transform_tensor_descriptor( + inDesc, + make_tuple(make_merge_transform(invariantDimLengths), + make_merge_transform(reduceDimLengths)), + make_tuple(InvariantDims{}, ReduceDims{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + }(); + + const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{}); + const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration; + const auto inPad_M = + math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength; + const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength; + + auto in_grid_desc_m_k_padded = transform_tensor_descriptor( + in_grid_desc_m_k, + make_tuple(make_right_pad_transform(invariantLength, inPad_M), + make_right_pad_transform(reduceLength, 
inPad_K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return (in_grid_desc_m_k_padded); + }; + + using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)); + + using GridwiseSoftmaxGeneric = GridwiseSoftmax_mk_to_mk; + + using GridwiseSoftmaxSweepOnce = GridwiseSoftmax_mk_to_mk; + + struct Argument : public BaseArgument + { + Argument(const std::vector inLengths, + const std::vector inStrides, + const std::vector reduceDims, + double alpha, + double beta, + const InDataType* in_dev, + OutDataType* out_dev, + InElementwiseOp in_elementwise_op, + AccElementwiseOp acc_elementwise_op) + : in_dev_{in_dev}, + out_dev_{out_dev}, + in_elementwise_op_{in_elementwise_op}, + acc_elementwise_op_{acc_elementwise_op} + { + alpha_ = static_cast(alpha); + beta_ = static_cast(beta); + + if(Rank != inLengths.size() || Rank != inStrides.size() || + NumReduceDim != reduceDims.size()) + { + throw std::runtime_error( + "One of inLengths/inStrides/reduceDims has invalid size!" + "\nExpected size inLengths: " + + std::to_string(Rank) + ", inStrides: " + std::to_string(Rank) + + ", reduceDims: " + std::to_string(NumReduceDim) + + "\nBut have inLengths: " + std::to_string(inLengths.size()) + + ", inStrides: " + std::to_string(inStrides.size()) + + ", reduceDims: " + std::to_string(reduceDims.size())); + } + + for(std::size_t i = 0; i < reduceDims.size(); ++i) + { + if(reduceDims[i] < 0 || reduceDims[i] >= Rank) + { + throw std::runtime_error("Provided reduce dimension exceed input tensor Rank!" + "\nHave reduceDims[" + + std::to_string(i) + + "]: " + std::to_string(reduceDims[i])); + } + } + + inLengths_ = shuffle_tensor_dimensions(inLengths, reduceDims); + inStrides_ = shuffle_tensor_dimensions(inStrides, reduceDims); + + long_index_t invariant_total_length; + long_index_t reduce_total_length; + + std::tie(invariant_total_length, reduce_total_length) = + get_2d_lengths(inLengths_); + + if constexpr(NumInvariantDim == 0) + invariant_lowest_length_ = 1; + else + invariant_lowest_length_ = inLengths_[NumInvariantDim - 1]; + + blkGroupSize = 1; + numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize; + + gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) / + M_BlockTileSize * blkGroupSize; + } + + std::vector inLengths_; + std::vector inStrides_; + + AccDataType alpha_; + AccDataType beta_; + + const InDataType* in_dev_; + OutDataType* out_dev_; + + InElementwiseOp in_elementwise_op_; + AccElementwiseOp acc_elementwise_op_; + + index_t invariant_lowest_length_; + + int blkGroupSize; + int numBlockTileIteration; + size_t gridSize; + }; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + const auto in_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor( + arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); + const auto out_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor( + arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration); + + bool sweep_once = + in_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize; + + const auto kernel_main = sweep_once ? 
kernel_softmax + : kernel_softmax; + + float avg_time = 0; + + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.gridSize), + dim3(BlockSize), + 0, + in_grid_desc_m_k, + out_grid_desc_m_k, + arg.blkGroupSize, + arg.numBlockTileIteration, + arg.alpha_, + arg.in_dev_, + arg.beta_, + arg.out_dev_); + + return (avg_time); + }; + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if constexpr(InSrcVectorDim == 0) + { + if constexpr(NumInvariantDim == 0) + { + return false; + } + else + { + if(arg.inStrides_[NumInvariantDim - 1] != 1 && InSrcVectorSize != 1) + { + return false; + } + if(arg.invariant_lowest_length_ % InSrcVectorSize != 0) + { + return false; + } + } + } + else + { + if(arg.inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1) + { + return false; + } + if(arg.inLengths_[Rank - 1] % InSrcVectorSize != 0) + { + return false; + } + } + + // To improve + if(NumInvariantDim > 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0) + { + return false; + } + + if(arg.inLengths_[Rank - 1] % OutDstVectorSize != 0) + { + return false; + } + + return true; + }; + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const std::vector inLengths, + const std::vector inStrides, + const std::vector reduceDims, + double alpha, + double beta, + const InDataType* in_dev, + OutDataType* out_dev, + InElementwiseOp in_elementwise_op, + AccElementwiseOp acc_elementwise_op) + { + return Argument{inLengths, + inStrides, + reduceDims, + alpha, + beta, + in_dev, + out_dev, + in_elementwise_op, + acc_elementwise_op}; + }; + + // + // @brief Makes a pointer to Argument class. + // + // @param[in] inLengths Input tensor extent(s) from high to low dimension + // @param[in] inStrides Input tensor stride(s) from high to low dimension + // @param[in] reduceDims The dimension(s) the normalization operation is applied + // @param[in] alpha Typeless pointer in host memory storing the alpha scaling + // value as type AccDataType + // @param[in] beta Typeless pointer in host memory storing the beta scaling + // value as type AccDataType + // @param[in] in_dev Typeless const pointer in device memory storing the input + // tensor + // @param out_dev Typeless pointer in device memory storing the output tensor + // @param[in] in_elementwise_op The input elementwise operation. + // @param[in] acc_elementwise_op The accumulation elementwise operation. + // + // @return Unique pointer to the Argument class. 
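Editor's note: the Argument constructor and Invoker::Run above drive the launch from two pieces of host-side arithmetic — numBlockTileIteration counts how many K block tiles are needed to cover the reduce length, gridSize assigns one block per M_BlockTileSize invariant rows, and the sweep-once kernel is chosen when a single block tile already covers the whole (padded) reduce length. A minimal standalone sketch of that arithmetic follows; the struct and helper names are illustrative stand-ins, not part of the CK API, and the MakeArgumentPointer declaration documented above continues directly after this note.

#include <cstdint>
#include <iostream>

// Standalone illustration (not CK code): mirrors the grid sizing done in
// DeviceSoftmaxImpl::Argument and the sweep-once choice in Invoker::Run,
// assuming blkGroupSize == 1 as in the constructor above.
struct SoftmaxLaunchPlan
{
    int64_t num_block_tile_iteration; // K tiles needed to cover the reduce length
    int64_t grid_size;                // one block per M_BlockTileSize invariant rows
    bool sweep_once;                  // one block tile covers the whole reduce length
};

inline SoftmaxLaunchPlan plan_softmax_launch(int64_t invariant_total_length,
                                             int64_t reduce_total_length,
                                             int64_t m_block_tile_size, // MThreadClusterSize * MThreadSliceSize
                                             int64_t k_block_tile_size) // KThreadClusterSize * KThreadSliceSize
{
    SoftmaxLaunchPlan plan{};
    plan.num_block_tile_iteration =
        (reduce_total_length + k_block_tile_size - 1) / k_block_tile_size;
    plan.grid_size =
        (invariant_total_length + m_block_tile_size - 1) / m_block_tile_size;
    plan.sweep_once = reduce_total_length <= k_block_tile_size;
    return plan;
}

int main()
{
    // e.g. softmax over the last dimension of a [64, 2000] tensor with
    // 32-row M tiles and 256-element K tiles
    const auto plan = plan_softmax_launch(64, 2000, 32, 256);
    std::cout << plan.grid_size << " blocks, " << plan.num_block_tile_iteration
              << " K iterations, sweep_once = " << plan.sweep_once << "\n";
    return 0;
}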
+ // + std::unique_ptr MakeArgumentPointer(const std::vector inLengths, + const std::vector inStrides, + const std::vector reduceDims, + double alpha, + double beta, + const void* in_dev, + void* out_dev, + InElementwiseOp in_elementwise_op, + AccElementwiseOp acc_elementwise_op) override + { + return std::make_unique(inLengths, + inStrides, + reduceDims, + alpha, + beta, + static_cast(in_dev), + static_cast(out_dev), + in_elementwise_op, + acc_elementwise_op); + }; + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(); + }; + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceReduceSoftmax<" + << Rank << "," << NumReduceDim << "," << BlockSize << "," + << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << "," + << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << "," + << "InSrcVectorDim_" << InSrcVectorDim + << "_InSrcVectorSize_" << InSrcVectorSize + << "_OutDstVectorSize_" << OutDstVectorSize << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp new file mode 100755 index 000000000000..df3c929c2ea9 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp @@ -0,0 +1,197 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#if __clang_major__ == 20 +#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm_builtins.hpp" +#else +#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp" +#endif + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceSparseEmbeddingsForwardLayernorm : public BaseOperator +{ + static auto MakeOutputDescriptor(const index_t index_length, const index_t rows) + { + return make_naive_tensor_descriptor_packed(make_tuple(index_length, rows)); + } + + struct Argument : public BaseArgument + { + Argument(OutType* p_out, + const ck::Array& p_embs, + const ck::Array& p_indexs, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + const ck::index_t EmbeddingDim, + const ck::index_t IndexLength, + const AccDataType epsilon, + const EmbElementwiseOperation emb_elementwise_op) + : p_out_(p_out), + p_embs_(p_embs), + p_indexs_(p_indexs), + p_gamma_(p_gamma), + p_beta_(p_beta), + EmbeddingDim_(EmbeddingDim), + IndexLength_(IndexLength), + epsilon_(epsilon), + emb_elementwise_op_(emb_elementwise_op) + { + grid_size_ = (IndexLength + DimClusterSize - 1) / DimClusterSize; + } + + OutType* p_out_; + ck::Array p_embs_; + ck::Array p_indexs_; + const GammaDataType* p_gamma_; + const BetaDataType* p_beta_; + ck::index_t EmbeddingDim_; + ck::index_t IndexLength_; + AccDataType epsilon_; + EmbElementwiseOperation emb_elementwise_op_; + + size_t grid_size_; + }; + + std::unique_ptr + MakeArgumentPointer(void* p_out, + const ck::Array& p_embs, + const ck::Array& p_indexs, + const void* p_gamma, + const void* p_beta, + ck::index_t EmbeddingDim, + ck::index_t IndexLength, + const AccDataType epsilon, + const EmbElementwiseOperation emb_elementwise_op) + { + return std::make_unique(reinterpret_cast(p_out), + p_embs, + p_indexs, + reinterpret_cast(p_gamma), + reinterpret_cast(p_beta), + EmbeddingDim, + IndexLength, + epsilon, + emb_elementwise_op); + } + + using GridwiseSparseEmbedding = + GridwiseSparseEmbeddingsForwardLayernorm; + + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + auto out_desc = MakeOutputDescriptor(arg.IndexLength_, arg.EmbeddingDim_); + const auto kernel_main = + kernel_sparse_embeddings_forward_layernorm; + float avg_time = 0; + avg_time += launch_and_time_kernel(stream_config, + kernel_main, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + arg.p_out_, + arg.p_embs_, + arg.p_indexs_, + arg.p_gamma_, + arg.p_beta_, + out_desc, + arg.epsilon_, + arg.emb_elementwise_op_); + + return (avg_time); + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + }; + }; + + static bool IsSupportedArgument(const Argument* p_arg) + { + return (RowPerBlock == p_arg->EmbeddingDim_); + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(dynamic_cast(p_arg)); + } + + virtual std::unique_ptr MakeInvokerPointer() + { + return std::make_unique(); + } + + std::string GetTypeString() const override + { + auto str = 
std::stringstream(); + + // clang-format off + str << "DeviceSparseEmbeddingsForwardLayernorm_"<< BlockSize << "_" << + DimClusterSize << "x" << RowClusterSize << "_" << + DimPerBlock << "x" << RowPerBlock << "_" << + DimThreadSize << "x" << RowVectorSize; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp new file mode 100755 index 000000000000..63b49d9aa0c7 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/impl/device_splitk_contraction_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,1150 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_splitk_contraction_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_contraction_multiple_d_xdl_cshuffle( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatDsPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + const index_t batch_count, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1, + const BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + FloatDsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + 
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseGemm::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_akb_ak0_m_ak1, + b_grid_desc_bkb_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_akb_ak0_m_ak1; + ignore = b_grid_desc_bkb_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; + ignore = compute_ptr_offset_of_batch; +#endif +} + +} // namespace ck + +namespace ck { +namespace tensor_operation { +namespace device { + +// Tensor Contraction: +// input : A +// input : B +// input : D0, D1, ... +// output : E +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] +// B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] +// D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +// E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...] +template +struct DeviceSplitKContractionMultipleD_Xdl_CShuffle + : public DeviceSplitKContractionMultipleD +{ + using DeviceOp = DeviceSplitKContractionMultipleD_Xdl_CShuffle; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto matrix_padder = + MatrixPadder{MPerBlock, NPerBlock, KPerBlock}; + + // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...] + static auto MakeAGridDescriptor_M_K(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK && + a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto a_ms_ks_lengths = to_tuple( + a_gs_ms_ks_lengths_vec, Number{}, Number{}); + const auto a_ms_ks_strides = to_tuple( + a_gs_ms_ks_strides_vec, Number{}, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds); + + // lengths for K0, K1, ... 
+ const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds); + + if constexpr(ASpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor( + make_tuple(M, K), + make_tuple(a_ms_ks_strides[Number{}], + a_ms_ks_strides[Number{}])); + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + else + { + // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] + const auto a_grid_desc_ms_ks = + make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides); + + // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] + const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( + a_grid_desc_ms_ks, + make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)), + make_tuple(mDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + } + + // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...] + static auto MakeBGridDescriptor_N_K(const std::vector& b_gs_ns_ks_lengths_vec, + const std::vector& b_gs_ns_ks_strides_vec) + { + assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK && + b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto b_ns_ks_lengths = to_tuple( + b_gs_ns_ks_lengths_vec, Number{}, Number{}); + const auto b_ns_ks_strides = to_tuple( + b_gs_ns_ks_strides_vec, Number{}, Number{}); + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{}; + + // dimension Ids for K0, K1, ... + constexpr auto kDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for K0, K1, ... + const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds); + + if constexpr(BSpec == TensorSpecialization::Packed) + { + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{}); + const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor( + make_tuple(N, K), + make_tuple(b_ns_ks_strides[Number{}], + b_ns_ks_strides[Number{}])); + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + else + { + // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...] + const auto b_grid_desc_ns_ks = + make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides); + + // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...] + const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor( + b_grid_desc_ns_ks, + make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)), + make_tuple(nDimIds, kDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + } + + // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] 
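Editor's note: MakeAGridDescriptor_M_K and MakeBGridDescriptor_N_K above collapse the grouped contraction dimensions into a flat 2D GEMM view — the per-group extents are multiplied together (MRaw = M0 * M1 * ..., KRaw = K0 * K1 * ...) before the result is padded to block-tile multiples. Below is a minimal standalone sketch of that flattening arithmetic; the names and shapes are hypothetical, not CK's, and the E-descriptor builder introduced by the comment above continues right after this note.

#include <array>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>

// Standalone illustration (not CK code): flatten grouped contraction extents
// A[G0, ..., M0, M1, ..., K0, K1, ...] into the 2D GEMM extents MRaw x KRaw,
// mirroring the container_reduce(..., multiplies) calls above.
template <std::size_t N>
int64_t product(const std::array<int64_t, N>& extents)
{
    return std::accumulate(extents.begin(), extents.end(), int64_t{1}, std::multiplies<>{});
}

int main()
{
    // Hypothetical contraction with NumDimM = 2 and NumDimK = 2.
    const std::array<int64_t, 2> m_lengths{16, 32}; // M0, M1
    const std::array<int64_t, 2> k_lengths{8, 64};  // K0, K1

    const int64_t m_raw = product(m_lengths); // 512
    const int64_t k_raw = product(k_lengths); // 512

    std::cout << "A is treated as a " << m_raw << " x " << k_raw
              << " matrix before tile padding\n";
    return 0;
}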
+ static auto MakeEGridDescriptor_M_N(const std::vector& e_gs_ms_ns_lengths_vec, + const std::vector& e_gs_ms_ns_strides_vec) + { + assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto e_ms_ns_lengths = to_tuple( + e_gs_ms_ns_lengths_vec, Number{}, Number{}); + const auto e_ms_ns_strides = to_tuple( + e_gs_ms_ns_strides_vec, Number{}, Number{}); + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... + const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds); + + if constexpr(DESpec == TensorSpecialization::Packed) + { + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(M, N), + make_tuple(e_ms_ns_strides[Number{}], + e_ms_ns_strides[Number{}])); + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + else + { + // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_ms_ns = + make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides); + + // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...] + const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + } + + // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + static auto MakeEGridDescriptor_G_M_N(const std::vector& e_gs_ms_ns_lengths_vec, + const std::vector& e_gs_ms_ns_strides_vec) + { + assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN); + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto e_gs_ms_ns_lengths = + to_tuple(e_gs_ms_ns_lengths_vec, Number<0>{}, Number{}); + const auto e_gs_ms_ns_strides = + to_tuple(e_gs_ms_ns_strides_vec, Number<0>{}, Number{}); + + // dimension Ids for G0, G1, ... + constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{}; + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = + typename arithmetic_sequence_gen::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = typename arithmetic_sequence_gen::type{}; + + // lengths for G0, G1, ... + const auto gLengths = get_container_subset(e_gs_ms_ns_lengths, gDimIds); + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(e_gs_ms_ns_lengths, mDimIds); + + // lengths for K0, K1, ... 
+ const auto nLengths = get_container_subset(e_gs_ms_ns_lengths, nDimIds); + + if constexpr(DESpec == TensorSpecialization::Packed) + { + auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{}); + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto e_grid_desc_g_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(G, M, N), + make_tuple(e_gs_ms_ns_strides[Number{}], + e_gs_ms_ns_strides[Number{}], + e_gs_ms_ns_strides[Number{}])); + // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw); + return e_grid_desc_g_mraw_nraw; + } + else + { + // naive tensor E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + const auto e_grid_desc_gs_ms_ns = + make_naive_tensor_descriptor(e_gs_ms_ns_lengths, e_gs_ms_ns_strides); + + // transformed tensor E[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] + const auto e_grid_desc_g_mraw_nraw = transform_tensor_descriptor( + e_grid_desc_gs_ms_ns, + make_tuple(make_merge_transform(gLengths), + make_merge_transform(mLengths), + make_merge_transform(nLengths)), + make_tuple(gDimIds, mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw); + return e_grid_desc_g_mraw_nraw; + } + } + + static auto MakeDsGridDescriptor_M_N( + const std::array, NumDTensor>& ds_gs_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_gs_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths_vec[i], + ds_gs_ms_ns_strides_vec[i]); + }, + Number{}); + } + + static auto MakeDsGridDescriptor_G_M_N( + const std::array, NumDTensor>& ds_gs_ms_ns_lengths_vec, + const std::array, NumDTensor>& ds_gs_ms_ns_strides_vec) + { + return generate_tuple( + [&](auto i) { + return DeviceOp::MakeEGridDescriptor_G_M_N(ds_gs_ms_ns_lengths_vec[i], + ds_gs_ms_ns_strides_vec[i]); + }, + Number{}); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K({}, {})); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K({}, {})); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); + + using DsGridDesc_G_M_N = remove_cvref_t; + using EGridDesc_G_M_N = decltype(MakeEGridDescriptor_G_M_N({}, {})); + + struct ComputePtrOffsetOfStridedBatch + { + ComputePtrOffsetOfStridedBatch(index_t batch_stride_A, + index_t batch_stride_B, + DsGridDesc_G_M_N ds_grid_desc_g_m_n, + EGridDesc_G_M_N e_grid_desc_g_m_n) + : batch_stride_A_(batch_stride_A), + batch_stride_B_(batch_stride_B), + ds_grid_desc_g_m_n_(ds_grid_desc_g_m_n), + e_grid_desc_g_m_n_(e_grid_desc_g_m_n) + { + } + + __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(batch_stride_A_); + } + + __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const + { + return g_idx * static_cast(batch_stride_B_); + } + + __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const + { + std::array ds_offset; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_offset[i] = static_cast(g_idx) * + ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(1, 0, 0)); + }); + + return ds_offset; + } + + __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const + { + return static_cast(g_idx) * + e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(1, 0, 0)); + } + + private: + index_t batch_stride_A_; + 
index_t batch_stride_B_; + DsGridDesc_G_M_N ds_grid_desc_g_m_n_; + EGridDesc_G_M_N e_grid_desc_g_m_n_; + }; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemmSplitKMultipleD_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::Set, + AGridDesc_M_K, + BGridDesc_N_K, + DsGridDesc_M_N, + EGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + // GridwiseGemm + using GridwiseGemmAtomicAdd = GridwiseGemmSplitKMultipleD_xdl_cshuffle< + ADataType, // TODO: distinguish A/B datatype + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + InMemoryDataOperationEnum::AtomicAdd, + AGridDesc_M_K, + BGridDesc_N_K, + DsGridDesc_M_N, + EGridDesc_M_N, + NumGemmKPrefetchStage, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CDEBlockTransferScalarPerVector_NPerBlock, + LoopSched>; + + using AGridDesc_AKB_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BKB_BK0_N_BK1 = + remove_cvref_t; + + using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t split_k) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_m_k_{ + 
DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ns_lengths, a_gs_ms_ks_strides)}, + b_grid_desc_n_k_{ + DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{ + DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, + ds_grid_desc_g_m_n_{ + DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)}, + e_grid_desc_g_m_n_{ + DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)}, + a_grid_desc_akb_ak0_m_ak1_{GridwiseGemm::MakeDefaultAGridDescriptor_AKB_AK0_M_AK1( + a_grid_desc_m_k_, split_k)}, + b_grid_desc_bkb_bk0_n_bk1_{GridwiseGemm::MakeDefaultBGridDescriptor_BKB_BK0_N_BK1( + b_grid_desc_n_k_, split_k)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{ + GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_, split_k)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_mz_stride_{}, + a_kz_stride_{}, + b_nz_stride_{}, + b_kz_stride_{}, + ds_nz_stride_{}, + e_nz_stride_{}, + a_batch_stride_{a_gs_ms_ks_strides[NumDimG - 1]}, + b_batch_stride_{b_gs_ns_ks_strides[NumDimG - 1]}, + compute_ptr_offset_of_batch_{ + a_batch_stride_, b_batch_stride_, ds_grid_desc_g_m_n_, e_grid_desc_g_m_n_}, + split_k_{split_k} + { + static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0, ""); + + // populate pointer, batch stride, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + + // D pointer + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + + // D desc + ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths[i], + ds_gs_ms_ns_strides[i]); + }); + + // populate desc for Ds/E + if(GridwiseGemm::CheckValidity(a_grid_desc_akb_ak0_m_ak1_, + b_grid_desc_bkb_bk0_n_bk1_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + e_grid_desc_m_n_); + + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n_); + } + + // for sanity check of vector memory access + a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1]; + a_kz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1]; + b_nz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN - 1]; + b_kz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1]; + + for(index_t i = 0; i < NumDTensor; ++i) + { + ds_nz_stride_[i] = ds_gs_ms_ns_strides[i][NumDimG + NumDimM + NumDimN - 1]; + } + + e_nz_stride_ = e_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1]; + + Print(); + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_.GetLength(I0) << ", " + << a_grid_desc_m_k_.GetLength(I1) << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_.GetLength(I0) << ", " + << b_grid_desc_n_k_.GetLength(I1) << std::endl; + + std::cout << "A[akb, ak0, m, ak1]: " << a_grid_desc_akb_ak0_m_ak1_.GetLength(I0) << ", " + << a_grid_desc_akb_ak0_m_ak1_.GetLength(I1) << ", " + << a_grid_desc_akb_ak0_m_ak1_.GetLength(I2) << ", " + << a_grid_desc_akb_ak0_m_ak1_.GetLength(I3) << std::endl; + std::cout << "B[bkb, bk0, n, bk1]: " << b_grid_desc_bkb_bk0_n_bk1_.GetLength(I0) << ", " + << b_grid_desc_bkb_bk0_n_bk1_.GetLength(I1) << ", " + << b_grid_desc_bkb_bk0_n_bk1_.GetLength(I2) << ", " + << 
b_grid_desc_bkb_bk0_n_bk1_.GetLength(I3) << std::endl; + static_for<0, NumDTensor, 1>{}([&](auto i) { + std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i].GetLength(I0) << ", " + << ds_grid_desc_m_n_[i].GetLength(I1) << std::endl; + }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_.GetLength(I0) << ", " + << e_grid_desc_m_n_.GetLength(I1) << std::endl; + } + + // private: + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + DsGridDesc_G_M_N ds_grid_desc_g_m_n_; + EGridDesc_G_M_N e_grid_desc_g_m_n_; + + // tensor descriptors for block/thread-wise copy + AGridDesc_AKB_AK0_M_AK1 a_grid_desc_akb_ak0_m_ak1_; + BGridDesc_BKB_BK0_N_BK1 b_grid_desc_bkb_bk0_n_bk1_; + typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise op + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // Strides for the last M/N/K dimensions of A/B/Ds/E + // for sanity check of vector load/store + index_t a_mz_stride_; + index_t a_kz_stride_; + index_t b_nz_stride_; + index_t b_kz_stride_; + std::array ds_nz_stride_; + index_t e_mz_stride_; + index_t e_nz_stride_; + + index_t a_batch_stride_; + index_t b_batch_stride_; + + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + index_t split_k_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_akb_ak0_m_ak1_, + arg.b_grid_desc_bkb_bk0_n_bk1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemmMultipleD_xdl_cshuffle has invalid setting"); + } + + const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0); + + const index_t grid_size = + arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G; + + const auto K = arg.a_grid_desc_akb_ak0_m_ak1_.GetLength(I1) * + arg.a_grid_desc_akb_ak0_m_ak1_.GetLength(I3); + + auto launch_kernel = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle< + GridwiseGemm, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemm::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AKB_AK0_M_AK1, + DeviceOp::BGridDesc_BKB_BK0_N_BK1, + typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + typename GridwiseGemm::DefaultBlock2ETileMap, + has_main_loop>; + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + G, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_akb_ak0_m_ak1_, + arg.b_grid_desc_bkb_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_etile_map_); + }; + + auto launch_kernel_atomic_add = [&](auto has_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop.value; + + const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle< + GridwiseGemmAtomicAdd, + ADataType, // TODO: distiguish A/B datatype + typename GridwiseGemmAtomicAdd::DsGridPointer, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + DeviceOp::AGridDesc_AKB_AK0_M_AK1, + DeviceOp::BGridDesc_BKB_BK0_N_BK1, + typename GridwiseGemmAtomicAdd:: + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + typename GridwiseGemmAtomicAdd:: + EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, + ComputePtrOffsetOfStridedBatch, + typename GridwiseGemmAtomicAdd::DefaultBlock2ETileMap, + has_main_loop>; + + hipGetErrorString(hipMemsetAsync( + arg.p_e_grid_, + 0, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() * + sizeof(EDataType), + stream_config.stream_id_)); + + return launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(BlockSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_ds_grid_, + arg.p_e_grid_, + G, + arg.a_element_op_, + arg.b_element_op_, + arg.cde_element_op_, + arg.a_grid_desc_akb_ak0_m_ak1_, + arg.b_grid_desc_bkb_bk0_n_bk1_, + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.compute_ptr_offset_of_batch_, + arg.block_2_etile_map_); + }; + + if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) + { + if(arg.split_k_ <= 1) + return launch_kernel(integral_constant{}); + else + return launch_kernel_atomic_add(integral_constant{}); + } + else + { + if(arg.split_k_ <= 1) + return launch_kernel(integral_constant{}); + else + return launch_kernel_atomic_add(integral_constant{}); + } + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& 
arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_akb_ak0_m_ak1_, + arg.b_grid_desc_bkb_bk0_n_bk1_, + arg.ds_grid_desc_m_n_, + arg.e_grid_desc_m_n_, + arg.block_2_etile_map_)) + { + return false; + } + + // check vector access + static_assert((ABlockTransferSrcVectorDim == 2 || ABlockTransferSrcVectorDim == 3) && + (BBlockTransferSrcVectorDim == 2 || BBlockTransferSrcVectorDim == 3), + "wrong!"); + + // vector memory access of A: could be on M or AK1 dimension + if constexpr(ABlockTransferSrcVectorDim == 2) + { + if(!(arg.a_mz_stride_ == 1 && + arg.a_grid_desc_akb_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == + 0)) + { + return false; + } + } + else + { + if(!(arg.a_kz_stride_ == 1 && + arg.a_grid_desc_akb_ak0_m_ak1_.GetLength(I3) % ABlockTransferSrcScalarPerVector == + 0)) + { + return false; + } + } + + // vector memory access of B: could be on N or BK1 dimension + if constexpr(BBlockTransferSrcVectorDim == 2) + { + if(!(arg.b_nz_stride_ == 1 && + arg.b_grid_desc_bkb_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == + 0)) + { + return false; + } + } + else + { + if(!(arg.b_kz_stride_ == 1 && + arg.b_grid_desc_bkb_bk0_n_bk1_.GetLength(I3) % BBlockTransferSrcScalarPerVector == + 0)) + { + return false; + } + } + + // vector memory access of Ds: always on NPerBlock dimension + bool valid_d_access = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + if(!(arg.ds_nz_stride_[i] == 1 && + arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0)) + { + valid_d_access = false; + } + }); + + if(valid_d_access == false) + { + return false; + } + + // vector memory access of E: always on NPerBlock dimension + if(!((arg.e_nz_stride_ == 1 && + arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % + CDEBlockTransferScalarPerVector_NPerBlock == + 0) || + CDEBlockTransferScalarPerVector_NPerBlock == 1)) + { + return false; + } + + return true; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t split_k) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_gs_ms_ns_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + ds_gs_ms_ns_lengths, + ds_gs_ms_ns_strides, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const std::vector& a_gs_ms_ns_lengths, + const std::vector& a_gs_ms_ks_strides, + const std::vector& b_gs_ns_ks_lengths, + const std::vector& b_gs_ns_ks_strides, + const std::array, NumDTensor>& ds_gs_ms_ns_lengths, + const std::array, NumDTensor>& ds_gs_ms_ns_strides, + const std::vector& 
e_gs_ms_ns_lengths, + const std::vector& e_gs_ms_ns_strides, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t split_k) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_gs_ms_ns_lengths, + a_gs_ms_ks_strides, + b_gs_ns_ks_lengths, + b_gs_ns_ks_strides, + ds_gs_ms_ns_lengths, + ds_gs_ms_ns_strides, + e_gs_ms_ns_lengths, + e_gs_ms_ns_strides, + a_element_op, + b_element_op, + cde_element_op, + split_k); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceSplitKContractionMultipleD_Xdl_CShuffle" + << "<" + << NumDimG << ", " + << NumDimM << ", " + << NumDimN << ", " + << NumDimK << ", " + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << ABlockTransferSrcVectorDim << ", " + << BBlockTransferSrcVectorDim + << ">"; + // clang-format on + + return str.str(); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/masking_specialization.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/masking_specialization.hpp new file mode 100755 index 000000000000..9fe2f0d9767c --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/masking_specialization.hpp @@ -0,0 +1,87 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct MaskingSpecialization +{ + MaskDisabled, + MaskOutUpperTriangle +}; + +#ifndef __HIPCC_RTC__ +inline std::string getMaskingSpecializationString(const MaskingSpecialization& s) +{ + switch(s) + { + case MaskingSpecialization::MaskDisabled: return "MaskDisabled"; + case MaskingSpecialization::MaskOutUpperTriangle: return "MaskOutUpperTriangle"; + default: return "Unrecognized specialization!"; + } +} +#endif + +struct MaskDisabledPredicate +{ + __host__ __device__ constexpr bool operator()(index_t /*m*/, index_t /*n*/) const + { + return false; + }; + + __host__ __device__ constexpr bool + IsTileSkippable(index_t /*m*/, index_t /*n*/, index_t /*m_tile*/, index_t /*n_tile*/) const + { + return false; + } +}; + +struct MaskOutUpperTrianglePredicate +{ + __host__ __device__ constexpr bool operator()(index_t m, index_t n) const { return n > m; } + + __host__ __device__ constexpr bool + IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t /*n_tile*/) const + { + return operator()(m + m_tile - 1, n); + } +}; + +// to track the points which need to be set to -inf on C0 +// Note: no need to reset M padding value, because they will not be stored out. 
+template +struct C0MatrixMask_impl +{ + __host__ __device__ constexpr C0MatrixMask_impl(index_t NRaw) + : NRaw_(NRaw), predicate_(MaskOutPredicate{}) + { + } + + __host__ __device__ constexpr bool IsNOutOfBound(/*index_t m, */ index_t n) const + { + return n >= NRaw_; + } + + __host__ __device__ constexpr bool IsMaskedElement(index_t m, index_t n) const + { + return predicate_(m, n) || IsNOutOfBound(n); + } + + __host__ __device__ constexpr bool + IsTileSkippable(index_t m, index_t n, index_t m_tile, index_t n_tile) const + { + return predicate_.IsTileSkippable(m, n, m_tile, n_tile); + } + + private: + // index_t MRaw_; + index_t NRaw_; + MaskOutPredicate predicate_; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/matrix_padder.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/matrix_padder.hpp new file mode 100755 index 000000000000..02941531475f --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/matrix_padder.hpp @@ -0,0 +1,395 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template + typename DoPads> // Sequence +__host__ __device__ constexpr auto +PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoPads) +{ + constexpr index_t num_dim = DoPads::Size(); + + static_assert(num_dim == TileLengths::Size() && num_dim == TensorDesc::GetNumOfDimension(), + "wrong! 
inconsistent # of dimensions"); + + // transforms + const auto transforms = generate_tuple( + [&](auto idim) { + const auto MRaw = desc.GetLength(idim); + + const auto MPerTile = tile_lengths[idim]; + + const auto M = math::integer_divide_ceil(MRaw, MPerTile) * MPerTile; + + const auto MPad = M - MRaw; + + const bool DoPadM = DoPads::At(idim); + + const auto MTransform = conditional_expr(make_right_pad_transform(MRaw, MPad), + make_pass_through_transform(MRaw)); + + return MTransform; + }, + Number{}); + + // lower dimension Id + const auto lower_dimss = + generate_tuple([&](auto idim) { return Sequence{}; }, Number{}); + + // upper dimension Id + const auto upper_dimss = lower_dimss; + + return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss); +} + +// M/N/K/OPerTileType could be index_t or Number<> +template +struct GemmGemmPadder +{ + // TODO: hard to scale; use mask instead + static constexpr bool PadM = + GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::MOPadding || GemmSpec == GemmSpecialization::MNOPadding || + GemmSpec == GemmSpecialization::MKOPadding || GemmSpec == GemmSpecialization::MNKOPadding; + static constexpr bool PadN = + GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::NOPadding || GemmSpec == GemmSpecialization::MNOPadding || + GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding; + static constexpr bool PadK = + GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KOPadding || GemmSpec == GemmSpecialization::MKOPadding || + GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding; + static constexpr bool PadO = + GemmSpec == GemmSpecialization::OPadding || GemmSpec == GemmSpecialization::MOPadding || + GemmSpec == GemmSpecialization::NOPadding || GemmSpec == GemmSpecialization::KOPadding || + GemmSpec == GemmSpecialization::MNOPadding || GemmSpec == GemmSpecialization::MKOPadding || + GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding; + + // A[M, K] + template + __host__ __device__ constexpr auto + PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const + { + return PadTensorDescriptor( + a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence{}); + } + + // B[K, N] + template + __host__ __device__ constexpr auto + PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const + { + return PadTensorDescriptor( + b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence{}); + } + + // B1[Gemm1N, Gemm1K] = B1[O, N] + template + __host__ __device__ constexpr auto + PadB1Descriptor_N_K(const B1Desc_NRaw_KRaw& b1_desc_nraw_kraw) const + { + return PadTensorDescriptor( + b1_desc_nraw_kraw, make_tuple(OPerTile_, NPerTile_), Sequence{}); + } + + // C[M, Gemm1N] = C[M, O] + template + __host__ __device__ constexpr auto + PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const + { + return PadTensorDescriptor( + c_desc_mraw_nraw, make_tuple(MPerTile_, OPerTile_), Sequence{}); + } + + MPerTileType MPerTile_; + NPerTileType 
NPerTile_; + KPerTileType KPerTile_; + OPerTileType OPerTile_; +}; + +// M/N/KPerTileType could be index_t or Number<> +template +struct GemmPadder +{ + static constexpr bool PadM = + (GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding); + static constexpr bool PadN = + (GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding); + static constexpr bool PadK = + (GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding); + + template + __host__ __device__ constexpr auto + PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const + { + return PadTensorDescriptor( + a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence{}); + } + + template + __host__ __device__ constexpr auto + PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const + { + return PadTensorDescriptor( + b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence{}); + } + + template + __host__ __device__ constexpr auto + PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const + { + return PadTensorDescriptor( + c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence{}); + } + + MPerTileType MPerTile_; + NPerTileType NPerTile_; + KPerTileType KPerTile_; +}; + +// Alias of GemmPadder; to deprecate +template +struct MatrixPadder : public GemmPadder +{ +}; + +// function to take in a struct of type MatrixPadder and call the appropriate function to get +// the output descriptor at runtime for codegen +template +auto grid_desc(MatrixPadder matrix_padder, + CDesc_MRaw_NRaw conv_desc) +{ + auto res = matrix_padder.PadCDescriptor_M_N(conv_desc); + return res; +} +// M/N/KPerTileType could be index_t or Number<> +template +struct GemmPadder_v2 +{ + template + __host__ __device__ constexpr auto + PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const + { + return PadTensorDescriptor( + a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence{}); + } + + template + __host__ __device__ constexpr auto + PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const + { + return PadTensorDescriptor( + b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence{}); + } + + template + __host__ __device__ constexpr auto + PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const + { + return PadTensorDescriptor( + c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence{}); + } + + MPerTileType MPerTile_; + NPerTileType NPerTile_; + KPerTileType KPerTile_; +}; + +// M/N/KPerTileType could be index_t or Number<> +template +struct MatrixPadder_v2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + template + __host__ __device__ constexpr auto + PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const + { + const auto MRaw = a_desc_mraw_kraw.GetLength(I0); + const auto KRaw = a_desc_mraw_kraw.GetLength(I1); + + const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_; + const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_; + + const auto MPad = M - MRaw; + const auto KPad = K - KRaw; + + if constexpr(PadM && PadK) + { + // pad both M and K + return 
transform_tensor_descriptor(a_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(PadM && (!PadK)) + { + // pad M, but not K + return transform_tensor_descriptor( + a_desc_mraw_kraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(KRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr((!PadM) && PadK) + { + // pad K, but not M + return transform_tensor_descriptor( + a_desc_mraw_kraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or K + return a_desc_mraw_kraw; + } + } + + template + __host__ __device__ constexpr auto + PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const + { + const auto NRaw = b_desc_nraw_kraw.GetLength(I0); + const auto KRaw = b_desc_nraw_kraw.GetLength(I1); + + const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_; + const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_; + + const auto NPad = N - NRaw; + const auto KPad = K - KRaw; + + if constexpr(PadN && PadK) + { + // pad both N and K + return transform_tensor_descriptor(b_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), + make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(PadN && (!PadK)) + { + // pad N, but not K + return transform_tensor_descriptor( + b_desc_nraw_kraw, + make_tuple(make_right_pad_transform(NRaw, NPad), make_pass_through_transform(KRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr((!PadN) && PadK) + { + // pad K, but not N + return transform_tensor_descriptor( + b_desc_nraw_kraw, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad N or K + return b_desc_nraw_kraw; + } + } + + template + __host__ __device__ constexpr auto + PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const + { + const auto MRaw = c_desc_mraw_nraw.GetLength(I0); + const auto NRaw = c_desc_mraw_nraw.GetLength(I1); + + const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_; + const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_; + + const auto MPad = M - MRaw; + const auto NPad = N - NRaw; + + if constexpr(PadM && PadN) + { + // pad M and N + return transform_tensor_descriptor(c_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), + make_right_pad_transform(NRaw, NPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(PadM && (!PadN)) + { + // pad M, but not N + return transform_tensor_descriptor( + c_desc_mraw_nraw, + make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr((!PadM) && PadN) + { + // pad N, but not M + return transform_tensor_descriptor( + c_desc_mraw_nraw, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)), + 
make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_desc_mraw_nraw; + } + } + + MPerTileType MPerTile_; + NPerTileType NPerTile_; + KPerTileType KPerTile_; +}; +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp new file mode 100755 index 000000000000..5351d4ef2439 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp @@ -0,0 +1,186 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_enums.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +// FIXME: can it be replaced with ck::Tuple? +#include + +namespace ck { + +// The templated struct reduce_binary_operator maps the enum Ids of binary operators to their +// respective functor classes. +// The boolean member "indexable" are also provided in reduce_binary_operactor for +// easier checking by the upper-layer codes in the kernels. + +template +struct reduce_binary_operator; + +template <> +struct reduce_binary_operator +{ + using opType = reduce::Add; + + static constexpr bool indexable = false; +}; + +template <> +struct reduce_binary_operator +{ + using opType = reduce::Mul; + + static constexpr bool indexable = false; +}; + +template <> +struct reduce_binary_operator +{ + using opType = reduce::Min; + + static constexpr bool indexable = true; +}; + +template <> +struct reduce_binary_operator +{ + using opType = reduce::Max; + + static constexpr bool indexable = true; +}; + +template <> +struct reduce_binary_operator +{ + using opType = reduce::AMax; + + static constexpr bool indexable = true; +}; + +template <> +struct reduce_binary_operator +{ + using opType = reduce::Add; + + static constexpr bool indexable = false; +}; + +template <> +struct reduce_binary_operator +{ + using opType = reduce::Add; + + static constexpr bool indexable = false; +}; + +template <> +struct reduce_binary_operator +{ + using opType = reduce::Add; + + static constexpr bool indexable = false; +}; + +// The templated struct reduce_unary_operator maps the enum Ids of Reduce operators to two unary +// functor classes. 
+// The two unary functors are called before and afer the Reduction is executed respectively +template +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::PassThrough; + using AccElementwiseOperation = tensor_operation::element_wise::PassThrough; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; +}; + +template +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::PassThrough; + using AccElementwiseOperation = tensor_operation::element_wise::UnaryDivide; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{reduceLength}); + }; +}; + +template +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs; + using AccElementwiseOperation = tensor_operation::element_wise::PassThrough; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; +}; + +template +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::UnaryAbs; + using AccElementwiseOperation = tensor_operation::element_wise::PassThrough; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; +}; + +template <> +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::UnarySquare; + using AccElementwiseOperation = tensor_operation::element_wise::PassThrough; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; +}; + +template <> +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::UnarySquare; + using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; +}; + +template <> +struct reduce_unary_operator +{ + using InElementwiseOperation = tensor_operation::element_wise::PassThrough; + using AccElementwiseOperation = tensor_operation::element_wise::UnarySqrt; + + static std::tuple + GetElementwiseOperator(int32_t reduceLength) + { + (void)reduceLength; + return std::make_tuple(InElementwiseOperation{}, AccElementwiseOperation{}); + }; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/tensor_layout.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/tensor_layout.hpp new file mode 100755 index 000000000000..e836e73a1d7a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/tensor_layout.hpp @@ -0,0 +1,451 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
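// --- Illustrative sketch (editorial, not part of the header above) ----------
// reduce_unary_operator pairs each reduction with a pre-op applied to every
// input element and a post-op applied to the accumulated value (for example,
// the divide-by-length post-op returned above for the averaging case). A
// self-contained sketch of that "pre-op / accumulate / post-op" pattern,
// written without the CK types (all names below are hypothetical):

#include <cstddef>

namespace example_reduce_sketch {

template <typename PreOp, typename BinOp, typename AccOp>
float reduce_with_elementwise_ops(
    const float* data, std::size_t n, float identity,
    PreOp pre_op, BinOp bin_op, AccOp acc_op)
{
    float acc = identity;
    for(std::size_t i = 0; i < n; ++i)
    {
        float t = 0.f;
        pre_op(t, data[i]);  // InElementwiseOperation: applied per element
        bin_op(acc, acc, t); // binary reduction step (Add-style accumulation)
    }
    float out = 0.f;
    acc_op(out, acc);        // AccElementwiseOperation: applied once at the end
    return out;
}

// Usage, with lambdas shaped like CK's PassThrough / Add / UnaryDivide:
//   auto pass  = [](float& y, float x) { y = x; };
//   auto add   = [](float& y, float a, float b) { y = a + b; };
//   auto div_n = [n](float& y, float x) { y = x / static_cast<float>(n); };
//   float mean = reduce_with_elementwise_ops(data, n, 0.f, pass, add, div_n);

} // namespace example_reduce_sketch
// -----------------------------------------------------------------------------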
+ +#pragma once + +namespace ck { +namespace tensor_layout { + +struct BaseTensorLayout +{ +}; + +namespace gemm { + +struct RowMajor : public BaseTensorLayout +{ + static constexpr const char* name = "RowMajor"; +}; + +struct ColumnMajor : public BaseTensorLayout +{ + static constexpr const char* name = "ColumnMajor"; +}; + +struct MFMA : public BaseTensorLayout +{ + static constexpr const char* name = "MFMA"; +}; + +} // namespace gemm + +namespace convolution { + +// input tensor +// packed NCW/NCHW/NCDHW +struct NCW : public BaseTensorLayout +{ + static constexpr const char* name = "NCW"; +}; + +struct NCHW : public BaseTensorLayout +{ + static constexpr const char* name = "NCHW"; +}; + +struct NCDHW : public BaseTensorLayout +{ + static constexpr const char* name = "NCDHW"; +}; + +// packed GNCW/GNCHW/GNCDHW +struct GNCW : public BaseTensorLayout +{ + static constexpr const char* name = "GNCW"; +}; + +struct GNCHW : public BaseTensorLayout +{ + static constexpr const char* name = "GNCHW"; +}; + +struct GNCDHW : public BaseTensorLayout +{ + static constexpr const char* name = "GNCDHW"; +}; + +// input tensor +// packed NWC/NHWC/NDHWC +struct NWC : public BaseTensorLayout +{ + static constexpr const char* name = "NWC"; +}; + +struct NHWC : public BaseTensorLayout +{ + static constexpr const char* name = "NHWC"; +}; + +struct NDHWC : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWC"; +}; + +// input tensor +// packed GNWC/GNHWC/GNDHWC +struct GNWC : public BaseTensorLayout +{ + static constexpr const char* name = "GNWC"; +}; + +struct GNHWC : public BaseTensorLayout +{ + static constexpr const char* name = "GNHWC"; +}; + +struct GNDHWC : public BaseTensorLayout +{ + static constexpr const char* name = "GNDHWC"; +}; + +// for input bias +struct GC : public BaseTensorLayout +{ + static constexpr const char* name = "GC"; +}; + +// input tensor +// packed NWGC/NHWGC/NDHWGC +struct NWGC : public BaseTensorLayout +{ + static constexpr const char* name = "NWGC"; +}; + +struct NHWGC : public BaseTensorLayout +{ + static constexpr const char* name = "NHWGC"; +}; + +struct NDHWGC : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWGC"; +}; + +// input tensor +// packed NGCW/NGCHW/NGCDHW +struct NGCW : public BaseTensorLayout +{ + static constexpr const char* name = "NGCW"; +}; + +struct NGCHW : public BaseTensorLayout +{ + static constexpr const char* name = "NGCHW"; +}; + +struct NGCDHW : public BaseTensorLayout +{ + static constexpr const char* name = "NGCDHW"; +}; + +// input tensor +// strided layout +struct G_NW_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_NW_C"; +}; + +struct G_NHW_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_NHW_C"; +}; + +struct G_NDHW_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_NDHW_C"; +}; + +// for input bias +struct G_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_C"; +}; + +// weight tensor +// packed KCX/KCYX/KCZYX +struct KCX : public BaseTensorLayout +{ + static constexpr const char* name = "KCX"; +}; + +struct KCYX : public BaseTensorLayout +{ + static constexpr const char* name = "KCYX"; +}; + +struct KCZYX : public BaseTensorLayout +{ + static constexpr const char* name = "KCZYX"; +}; + +// weight tensor +// packed KCX/KCYX/KCZYX +struct GKCX : public BaseTensorLayout +{ + static constexpr const char* name = "GKCX"; +}; + +struct GKCYX : public BaseTensorLayout +{ + static constexpr const char* name = "GKCYX"; +}; 
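// --- Illustrative sketch (editorial, not part of the header above) ----------
// The layout structs in this file are empty tag types: kernels branch on them
// at compile time, and the static `name` member is only used for logging. A
// stand-alone version of that dispatch pattern (hypothetical tags, not the
// ck:: ones):

#include <type_traits>

namespace example_layout_sketch {

struct RowMajorTag    { static constexpr const char* name = "RowMajor"; };
struct ColumnMajorTag { static constexpr const char* name = "ColumnMajor"; };

// Leading dimension of a densely stored M x N matrix in the given layout.
template <typename Layout>
constexpr int leading_dimension(int M, int N)
{
    if constexpr(std::is_same_v<Layout, RowMajorTag>)
        return N; // rows are contiguous: stride between rows is N
    else
        return M; // columns are contiguous: stride between columns is M
}

static_assert(leading_dimension<RowMajorTag>(128, 256) == 256, "sketch check");
static_assert(leading_dimension<ColumnMajorTag>(128, 256) == 128, "sketch check");

} // namespace example_layout_sketch
// -----------------------------------------------------------------------------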
+ +struct GKCZYX : public BaseTensorLayout +{ + static constexpr const char* name = "GKCZYX"; +}; + +// weight tensor +// packed KXC/KYXC/KZYXC +struct KXC : public BaseTensorLayout +{ + static constexpr const char* name = "KXC"; +}; + +struct KYXC : public BaseTensorLayout +{ + static constexpr const char* name = "KYXC"; +}; + +struct KZYXC : public BaseTensorLayout +{ + static constexpr const char* name = "KZYXC"; +}; + +// weight tensor +// packed GKXC/GKYXC/GKZYXC +struct GKXC : public BaseTensorLayout +{ + static constexpr const char* name = "GKXC"; +}; + +struct GKYXC : public BaseTensorLayout +{ + static constexpr const char* name = "GKYXC"; +}; + +struct GKZYXC : public BaseTensorLayout +{ + static constexpr const char* name = "GKZYXC"; +}; + +// weight tensor +// packed KXGC/KYXGC/KZYXGC +struct KXGC : public BaseTensorLayout +{ + static constexpr const char* name = "KXGC"; +}; + +struct KYXGC : public BaseTensorLayout +{ + static constexpr const char* name = "KYXGC"; +}; + +struct KZYXGC : public BaseTensorLayout +{ + static constexpr const char* name = "KZYXGC"; +}; + +// weight tensor +// strided +struct G_K_X_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_K_X_C"; +}; + +struct G_K_YX_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_K_YX_C"; +}; + +struct G_K_ZYX_C : public BaseTensorLayout +{ + static constexpr const char* name = "G_K_ZYX_C"; +}; + +// output tensor +// packed NKW/NKHW/NKDHW +struct NKW : public BaseTensorLayout +{ + static constexpr const char* name = "NKW"; +}; + +struct NKHW : public BaseTensorLayout +{ + static constexpr const char* name = "NKHW"; +}; + +struct NKDHW : public BaseTensorLayout +{ + static constexpr const char* name = "NKDHW"; +}; + +// output tensor +// packed GNKW/GNKHW/GNKDHW +struct GNKW : public BaseTensorLayout +{ + static constexpr const char* name = "GNKW"; +}; + +struct GNKHW : public BaseTensorLayout +{ + static constexpr const char* name = "GNKHW"; +}; + +struct GNKDHW : public BaseTensorLayout +{ + static constexpr const char* name = "GNKDHW"; +}; + +// output tensor +// packed NWK/NHWK/NDHWK +struct NWK : public BaseTensorLayout +{ + static constexpr const char* name = "NWK"; +}; + +struct NHWK : public BaseTensorLayout +{ + static constexpr const char* name = "NHWK"; +}; + +struct NDHWK : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWK"; +}; + +// output tensor +// packed GNWK/GNHWK/GNDHWK +struct GNWK : public BaseTensorLayout +{ + static constexpr const char* name = "GNWK"; +}; + +struct GNHWK : public BaseTensorLayout +{ + static constexpr const char* name = "GNHWK"; +}; + +struct GNDHWK : public BaseTensorLayout +{ + static constexpr const char* name = "GNDHWK"; +}; + +// output tensor +// packed NWGK/NHWGK/NDHWGK +struct NWGK : public BaseTensorLayout +{ + static constexpr const char* name = "NWGK"; +}; + +struct NHWGK : public BaseTensorLayout +{ + static constexpr const char* name = "NHWGK"; +}; + +struct NDHWGK : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWGK"; +}; + +struct NGKW : public BaseTensorLayout +{ + static constexpr const char* name = "NGKW"; +}; + +struct NGKHW : public BaseTensorLayout +{ + static constexpr const char* name = "NGKHW"; +}; + +struct NGKDHW : public BaseTensorLayout +{ + static constexpr const char* name = "NGKDHW"; +}; + +// output tensor +// strided layout +struct G_NW_K : public BaseTensorLayout +{ + static constexpr const char* name = "G_NW_K"; +}; + +struct G_NHW_K : public BaseTensorLayout 
+{ + static constexpr const char* name = "G_NHW_K"; +}; + +struct G_NDHW_K : public BaseTensorLayout +{ + static constexpr const char* name = "G_NDHW_K"; +}; + +// for output bias +struct G_K : public BaseTensorLayout +{ + static constexpr const char* name = "G_K"; +}; + +// K-reduced output tensor (packed) +struct GNW : public BaseTensorLayout +{ + static constexpr const char* name = "GNW"; +}; + +struct GNHW : public BaseTensorLayout +{ + static constexpr const char* name = "GNHW"; +}; + +struct GNDHW : public BaseTensorLayout +{ + static constexpr const char* name = "GNDHW"; +}; + +// K-reduced output tensor (packed) +struct NWG : public BaseTensorLayout +{ + static constexpr const char* name = "NWG"; +}; + +struct NHWG : public BaseTensorLayout +{ + static constexpr const char* name = "NHWG"; +}; + +struct NDHWG : public BaseTensorLayout +{ + static constexpr const char* name = "NDHWG"; +}; + +// K-reduced output tensor (strided) +struct G_NW : public BaseTensorLayout +{ + static constexpr const char* name = "G_NW"; +}; + +struct G_NHW : public BaseTensorLayout +{ + static constexpr const char* name = "G_NHW"; +}; + +struct G_NDHW : public BaseTensorLayout +{ + static constexpr const char* name = "G_NDHW"; +}; + +} // namespace convolution + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +template < + typename Layout, + typename std::enable_if::value, bool>::type = false> +std::ostream& operator<<(std::ostream& os, const Layout&) +{ + os << Layout::name; + return os; +} +#endif + +} // namespace tensor_layout +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp new file mode 100755 index 000000000000..713fc93ebb4e --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/tensor_specialization.hpp @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { + +enum struct TensorSpecialization +{ + Default, + Packed +}; + +inline std::string getTensorSpecializationString(const TensorSpecialization& s) +{ + switch(s) + { + case TensorSpecialization::Default: return "Default"; + case TensorSpecialization::Packed: return "Packed"; + default: return "Unrecognized specialization!"; + } +} + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/welford_helper.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/welford_helper.hpp new file mode 100755 index 000000000000..d7772d8764ff --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/device/welford_helper.hpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
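// --- Illustrative sketch (editorial, not part of the header above) ----------
// The enable_if-guarded operator<< defined for BaseTensorLayout lets any
// derived tag be streamed by name, which is how instance descriptions are
// typically assembled for logging. A stand-alone sketch of the same idiom
// with hypothetical tag types:

#include <iostream>
#include <type_traits>

namespace example_stream_sketch {

struct BaseTag {};
struct NHWCTag : BaseTag { static constexpr const char* name = "NHWC"; };

// Only participates in overload resolution for types derived from BaseTag.
template <typename Tag,
          typename std::enable_if<std::is_base_of<BaseTag, Tag>::value, bool>::type = false>
std::ostream& operator<<(std::ostream& os, const Tag&)
{
    return os << Tag::name;
}

inline void demo()
{
    std::cout << NHWCTag{} << std::endl; // prints "NHWC"
}

} // namespace example_stream_sketch
// -----------------------------------------------------------------------------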
+ +#pragma once + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct GetReduceCountPerThreadForBlockwiseWelford +{ + GetReduceCountPerThreadForBlockwiseWelford(index_t numBlockTileIteration, + long_index_t reduce_length) + : numBlockTileIteration_{numBlockTileIteration} + { + count_in_last_tile_ = reduce_length % K_BlockTileSize; + }; + + __device__ index_t operator()(index_t thread_k_cluster_id) const + { + if(count_in_last_tile_ == 0) + return (KThreadSliceSize * numBlockTileIteration_); + else + { + index_t num_complete_slice = count_in_last_tile_ / KThreadSliceSize; + index_t count_in_last_slice = count_in_last_tile_ % KThreadSliceSize; + + if(thread_k_cluster_id < num_complete_slice) + return (KThreadSliceSize * numBlockTileIteration_); + else if(thread_k_cluster_id == num_complete_slice) + return (KThreadSliceSize * (numBlockTileIteration_ - 1) + count_in_last_slice); + else + return (KThreadSliceSize * (numBlockTileIteration_ - 1)); + }; + }; + + index_t numBlockTileIteration_; + index_t count_in_last_tile_; +}; + +template +struct GetReduceCountPerThreadForMultiblockWelford +{ + GetReduceCountPerThreadForMultiblockWelford(index_t blkGroupSize, + index_t numBlockTileIteration, + long_index_t reduce_length) + : blkGroupSize_(blkGroupSize), numBlockTileIteration_{numBlockTileIteration} + { + last_block_reduce_length_ = + reduce_length - K_BlockTileSize * numBlockTileIteration_ * (blkGroupSize_ - 1); + numBlockTileIterationByLastBlock_ = + (last_block_reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize; + }; + + __device__ index_t operator()(index_t block_local_id, index_t thread_k_cluster_id) const + { + if(last_block_reduce_length_ == K_BlockTileSize * numBlockTileIteration_ || + block_local_id < blkGroupSize_ - 1) + return (KThreadSliceSize * numBlockTileIteration_); + + index_t count_in_last_tile = last_block_reduce_length_ % K_BlockTileSize; + + if(count_in_last_tile == 0) + return (KThreadSliceSize * numBlockTileIterationByLastBlock_); + else + { + index_t num_complete_slice = count_in_last_tile / KThreadSliceSize; + + if(thread_k_cluster_id < num_complete_slice) + return (KThreadSliceSize * numBlockTileIterationByLastBlock_); + else if(thread_k_cluster_id == num_complete_slice) + return (KThreadSliceSize * (numBlockTileIterationByLastBlock_ - 1) + + count_in_last_tile); + else + return (KThreadSliceSize * (numBlockTileIterationByLastBlock_ - 1)); + }; + }; + + index_t blkGroupSize_; + index_t numBlockTileIteration_; + + index_t last_block_reduce_length_; + index_t numBlockTileIterationByLastBlock_; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp new file mode 100755 index 000000000000..34c76b89e49b --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp @@ -0,0 +1,769 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
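// --- Illustrative sketch (editorial, not part of the header above) ----------
// GetReduceCountPerThreadForBlockwiseWelford hands every thread the exact
// number of valid elements it will feed into the Welford update, so the ragged
// last K tile does not bias the mean/variance. A host-only restatement of that
// per-thread count, using the same arithmetic as the functor above
// (hypothetical names, no CK types):

namespace example_welford_count_sketch {

constexpr int count_per_thread(int k_block_tile_size,   // K_BlockTileSize
                               int k_thread_slice_size, // KThreadSliceSize
                               int num_block_tile_iter, // numBlockTileIteration_
                               long reduce_length,
                               int thread_k_cluster_id)
{
    const int count_in_last_tile = static_cast<int>(reduce_length % k_block_tile_size);
    if(count_in_last_tile == 0)
        return k_thread_slice_size * num_block_tile_iter;

    const int num_complete_slice  = count_in_last_tile / k_thread_slice_size;
    const int count_in_last_slice = count_in_last_tile % k_thread_slice_size;

    if(thread_k_cluster_id < num_complete_slice)
        return k_thread_slice_size * num_block_tile_iter;
    if(thread_k_cluster_id == num_complete_slice)
        return k_thread_slice_size * (num_block_tile_iter - 1) + count_in_last_slice;
    return k_thread_slice_size * (num_block_tile_iter - 1);
}

// e.g. reduce_length = 70, K_BlockTileSize = 32, KThreadSliceSize = 8,
// numBlockTileIteration = 3: the last tile holds only 6 elements, so the first
// thread counts 8*2 + 6 = 22 and every later thread counts 8*2 = 16.
static_assert(count_per_thread(32, 8, 3, 70, 0) == 22, "sketch self-check");
static_assert(count_per_thread(32, 8, 3, 70, 1) == 16, "sketch self-check");

} // namespace example_welford_count_sketch
// -----------------------------------------------------------------------------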
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +struct Add +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const half_t& x1) const + { + y = x0 + type_convert(x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const float& x1) const + { + y = type_convert(x0 + x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + y = type_convert(x0) + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = x0 + x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x1); + y = x0 + x1_tmp; + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp + x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x0 + x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + y = x0 + x1; + }; +}; + +struct Max +{ + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const + { + const Y x0_converted = type_convert(x0); + const Y x1_converted = type_convert(x1); + y = ck::math::max(x0_converted, x1_converted); + } +}; + +struct Min +{ + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const + { + const Y x0_converted = type_convert(x0); + const Y x1_converted = type_convert(x1); + y = ck::math::min(x0_converted, x1_converted); + } +}; + +struct Multiply +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const half_t& x1) const + { + y = x0 * type_convert(x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const float& x1) const + { + y = type_convert(x0 * x1); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + y = type_convert(x0) * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const 
half_t& x1) const + { + y = x0 * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x1); + y = x0 * x1_tmp; + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp * x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const int8_t& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp * x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x0 * x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + y = x0 * x1; + }; +}; + +struct ScaleAdd +{ + __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {} + + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const + { + y = ck::type_convert(scale_ * ck::type_convert(x0) + ck::type_convert(x1)); + } + + template <> + __host__ __device__ void + operator()(float& y, const float& x0, const half_t& x1) const + { + y = scale_ * x0 + ck::type_convert(x1); + }; + + template <> + __host__ __device__ void + operator()(float& y, const float& x0, const bhalf_t& x1) const + { + y = scale_ * x0 + ck::type_convert(x1); + }; + + float scale_; +}; + +struct Subtract +{ + template + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = x0 - x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = x0 - x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = x0 - x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x0); + const float x2_tmp = ck::type_convert(x1); + const float y_tmp = x1_tmp - x2_tmp; + y = ck::type_convert(y_tmp); + } + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + y = x0 - x1; + }; +}; + +struct Bilinear +{ + Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + y = alpha_ * x0 + beta_ * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + y = alpha_ * x0 + beta_ * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + y = type_convert(alpha_ * type_convert(x0) + + beta_ * type_convert(x1)); + }; + + template <> + __host__ __device__ 
constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + y = type_convert(alpha_) * x0 + type_convert(beta_) * x1; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + y = type_convert(alpha_ * x0 + beta_ * ck::type_convert(x1)); + }; + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float x0_tmp = type_convert(x0); + const float x1_tmp = type_convert(x1); + const float y_tmp = alpha_ * x0_tmp + beta_ * x1_tmp; + y = type_convert(y_tmp); + }; + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float x1_tmp = ck::type_convert(x1); + const float y_tmp = alpha_ * x0 + beta_ * x1_tmp; + y = y_tmp; + }; + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x0, const int8_t& x1) const + { + y = type_convert(alpha_ * type_convert(x0) + + beta_ * type_convert(x1)); + }; + + float alpha_; + float beta_; +}; + +struct AddClamp +{ + AddClamp(float floor = 0.f, float ceil = NumericLimits::Max()) + : floor_(floor), ceil_(ceil){}; + + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + const float a = x0 + x1; + y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + const double a = x0 + x1; + y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + const half_t a = x0 + x1; + y = a > type_convert(floor_) + ? (a < type_convert(ceil_) ? a : type_convert(ceil_)) + : type_convert(floor_); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + const float a = x0 + type_convert(x1); + const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + y = type_convert(b); + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const half_t& x1) const + { + const float a = x0 + type_convert(x1); + y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + }; + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float a = x0 + type_convert(x1); + const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + y = type_convert(b); + }; + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float a = type_convert(x0) + type_convert(x1); + const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + y = type_convert(b); + }; + + template <> + __host__ __device__ constexpr void + operator()(int& y, const int& x0, const int8_t& x1) const + { + const int8_t a = x0 + x1; + y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + }; + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + const int8_t a = x0 + x1; + y = a > floor_ ? (a < ceil_ ? 
a : ceil_) : floor_; + }; + + const float floor_; + const float ceil_; +}; + +struct AddRelu +{ + template + __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + const float a = x0 + x1; + y = a > 0.0f ? a : 0.0f; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + const double a = x0 + x1; + y = a > 0.0 ? a : 0.0; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + const half_t a = x0 + x1; + y = a > type_convert(0.0f) ? a : type_convert(0.0f); + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const float& x0, const half_t& x1) const + { + const float a = x0 + type_convert(x1); + const float b = a > 0.0f ? a : 0.0f; + y = type_convert(b); + }; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const half_t& x1) const + { + const float a = x0 + type_convert(x1); + y = a > 0.0f ? a : 0.0f; + }; + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const float& x0, const bhalf_t& x1) const + { + const float a = x0 + type_convert(x1); + const float b = a > 0.0f ? a : 0.0f; + y = type_convert(b); + }; + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const + { + const float a = type_convert(x0) + type_convert(x1); + const float b = a > 0.0f ? a : 0.0f; + y = type_convert(b); + }; + + template <> + __host__ __device__ constexpr void + operator()(int& y, const int& x0, const int8_t& x1) const + { + const int8_t a = x0 + x1; + y = a > 0 ? a : 0; + }; + + template <> + __host__ __device__ constexpr void + operator()(int8_t& y, const int8_t& x0, const int8_t& x1) const + { + const int8_t a = x0 + x1; + y = a > 0 ? a : 0; + }; +}; + +struct AddHardswish +{ + template + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f; + y = c; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + double a = x0 + x1; + double b = a + 3.0; + double c = (b > 0) * (b > 6.0 ? 6.0 : b) * a * 0.166667; + y = c; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + float a = x0 + x1; + float b = a + 3.0f; + float c = (b > 0) * (b > 6.0f ? 
6.0f : b) * a * 0.166667f; + y = c; + }; +}; + +// E = FastGelu(C + D) +struct AddFastGelu +{ + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void + operator()(float& e, const float& c, const float& d) const + { + const float x = c + d; + + FastGelu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const half_t& c, const half_t& d) const + { + const half_t x = c + d; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const float& c, const half_t& d) const + { + const float x0_f = c + d; + + float x1_f = 0; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(x1_f, + x0_f); + + e = type_convert(x1_f); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const + { + const float x0_f = type_convert(c) + type_convert(d); + + float x1_f = 0; + + FastGelu{}.template operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& e, const float& c, const bhalf_t& d) const + { + const float x0_f = c + type_convert(d); + + float x1_f = 0; + + FastGelu{}.template operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } +}; + +// E = MultiplyFastGelu(C + D) +struct MultiplyFastGelu +{ + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void + operator()(float& e, const float& c, const float& d) const + { + const float x = c * d; + + FastGelu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const half_t& c, const half_t& d) const + { + const half_t x = c * d; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const float& c, const half_t& d) const + { + const float x0_f = c * d; + + float x1_f = 0; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(x1_f, + x0_f); + + e = type_convert(x1_f); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const + { + const float x0_f = type_convert(c) * type_convert(d); + + float x1_f = 0; + + FastGelu{}.template operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& e, const float& c, const bhalf_t& d) const + { + const float x0_f = c * type_convert(d); + + float x1_f = 0; + + FastGelu{}.template operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } +}; + +// E = Silu(C + D) +struct AddSilu +{ + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ constexpr void + operator()(float& e, const float& c, const float& d) const + { + const float x = c + d; + + Silu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const half_t& c, const half_t& d) const + { + const half_t x = c + d; + + Silu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void + operator()(half_t& e, const float& c, const half_t& d) const + { + const float x0_f = c + d; + + float x1_f = 0; + + Silu{}.template 
operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } + + template <> + __host__ __device__ constexpr void + operator()(bhalf_t& e, const float& c, const bhalf_t& d) const + { + const float x0_f = c + type_convert(d); + + float x1_f = 0; + + Silu{}.template operator()(x1_f, x0_f); + + e = type_convert(x1_f); + } +}; + +struct ConvScaleAdd +{ + __host__ __device__ ConvScaleAdd(float scale_in = 1.f, + float scale_wei = 1.f, + float scale_out = 1.f) + : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) + { + } + + template + __host__ __device__ void operator()(E& e, const C& c, const D& d) const; + + template <> + __host__ __device__ void + operator()(f8_t& e, const float& c, const float& d) const + { + float x; + Add{}.template operator()(x, c * scale_in_ * scale_wei_, d); + e = type_convert(x * scale_out_); + }; + + float scale_in_; + float scale_wei_; + float scale_out_; +}; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp new file mode 100755 index 000000000000..3cc1c3c42ccc --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +// y = UnaryOp0(UnaryOp1(...(x))) +template +struct UnaryCombinedOp +{ + __host__ __device__ UnaryCombinedOp() : unary_ops_() {} + + __host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) 
{} + + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + // Execute first unary op to copy data to y + unary_ops_.At(Number<0>{})(y, x); + + static_for<1, Tuple::Size(), 1>{}([&](auto i) { unary_ops_.At(i)(y, y); }); + }; + + Tuple unary_ops_; +}; + +// y = BinaryOp(UnaryOp0(x0), UnaryOp1(x1)) +template +struct BinaryWithUnaryCombinedOp +{ + __host__ __device__ BinaryWithUnaryCombinedOp() : binary_op_(), unary_op0_(), unary_op1_() {} + + __host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op, + UnaryOp0 unary_op0, + UnaryOp1 unary_op1) + : binary_op_(binary_op), unary_op0_(unary_op0), unary_op1_(unary_op1) + { + } + + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const + { + Y unary_x0_tmp_result; + Y unary_x1_tmp_result; + unary_op0_(unary_x0_tmp_result, x0); + unary_op1_(unary_x1_tmp_result, x1); + binary_op_(y, unary_x0_tmp_result, unary_x1_tmp_result); + }; + + private: + BinaryOp binary_op_; + UnaryOp0 unary_op0_; + UnaryOp1 unary_op1_; +}; + +// y = BinaryOp0(BinaryOp1(UnaryOp0(x0), UnaryOp1(x1)), UnaryOp2(x2)) +template +struct TrinaryWithUnaryCombinedOp +{ + __host__ __device__ TrinaryWithUnaryCombinedOp() + : binary_op0_(), binary_op1_(), unary_op0_(), unary_op1_(), unary_op2_() + { + } + + __host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0, + BinaryOp0 binary_op1, + UnaryOp0 unary_op0, + UnaryOp1 unary_op1, + UnaryOp2 unary_op2) + : binary_op0_(binary_op0), + binary_op1_(binary_op1), + unary_op0_(unary_op0), + unary_op1_(unary_op1), + unary_op2_(unary_op2) + { + } + + template + __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1, const X2& x2) const + { + + Y unary_x0_tmp_result; + Y unary_x1_tmp_result; + Y unary_x2_tmp_result; + unary_op0_(unary_x0_tmp_result, x0); + unary_op1_(unary_x1_tmp_result, x1); + unary_op2_(unary_x2_tmp_result, x2); + binary_op0_(unary_x0_tmp_result, unary_x0_tmp_result, unary_x1_tmp_result); + binary_op1_(y, unary_x0_tmp_result, unary_x2_tmp_result); + }; + + private: + BinaryOp0 binary_op0_{}; + BinaryOp1 binary_op1_{}; + UnaryOp0 unary_op0_{}; + UnaryOp1 unary_op1_{}; + UnaryOp2 unary_op2_{}; +}; + +using ScaleScalePass = UnaryCombinedOp; +using ScaleScaleRelu = UnaryCombinedOp; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp new file mode 100755 index 000000000000..b57ae22172ca --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp @@ -0,0 +1,588 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
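// --- Illustrative sketch (editorial, not part of the header above) ----------
// UnaryCombinedOp simply chains its member functors left to right: the first
// op writes y from x, every following op rewrites y in place. A self-contained
// sketch of that composition with two hypothetical functors shaped like the
// CK element-wise ops:

namespace example_combined_op_sketch {

struct ScaleBy2
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const { y = static_cast<Y>(x) * Y{2}; }
};

struct ClampAtZero // relu-style
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const { y = x > X{0} ? static_cast<Y>(x) : Y{0}; }
};

// Minimal two-op version of the chaining done by UnaryCombinedOp::operator().
template <typename Op0, typename Op1>
struct Chain2
{
    Op0 op0_{};
    Op1 op1_{};

    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const
    {
        op0_(y, x); // first op: x -> y
        op1_(y, y); // later ops: y -> y, applied in order
    }
};

// Usage: Chain2<ScaleBy2, ClampAtZero>{}(out, -3.f) leaves out == 0.f,
// while an input of 3.f yields 6.f.

} // namespace example_combined_op_sketch
// -----------------------------------------------------------------------------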
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/quantization_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +// Need to ensure compiler will fail if there is no matching candidate, instead of compiler +// siliently do implicit type conversion +// +// Example: +// +// struct ExampleElementwiseOp +// { +// template +// __host__ __device__ constexpr void +// operator()(Y&, const X) const; +// +// template<> +// __host__ __device__ constexpr void +// operator()(half_t& y, const half_t& x) const +// { +// } +// }; + +struct AddReluAdd +{ + template + __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const; + + template <> + __host__ __device__ constexpr void operator()( + half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const + { + half_t a = x0 + x1; + half_t b = a > 0 ? a : 0; + y = b + x2; + } + + template <> + __host__ __device__ constexpr void operator()(float& y, + const float& x0, + const float& x1, + const float& x2) const + { + float a = x0 + x1; + float b = a > 0 ? a : 0; + float c = b + x2; + y = c; + } + + template <> + __host__ __device__ constexpr void operator()( + half_t& y, const float& x0, const half_t& x1, const half_t& x2) const + { + float a = x0 + x1; + float b = a > 0 ? a : 0; + float c = b + x2; + y = c; + } + + template <> + __host__ __device__ constexpr void operator()( + bhalf_t& y, const float& x0, const bhalf_t& x1, const bhalf_t& x2) const + { + float a = x0 + x1; + float b = a > 0 ? a : 0; + float c = b + x2; + y = c; + } + + template <> + __host__ __device__ constexpr void operator()( + int8_t& y, const int8_t& x0, const int8_t& x1, const int8_t& x2) const + { + int32_t a = x0 + x1; + int32_t b = a > 0 ? a : 0; + int32_t c = b + x2; + y = c; + } + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + template <> + __host__ __device__ constexpr void operator()( + int4_t& y, const int8_t& x0, const int4_t& x1, const int4_t& x2) const + { + int32_t a = x0 + x1; + int32_t b = a > 0 ? a : 0; + int32_t c = b + x2; + y = c; + } +#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +}; + +struct AddHardswishAdd +{ + template + __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const; + + template <> + __host__ __device__ constexpr void operator()(float& y, + const float& x0, + const float& x1, + const float& x2) const + { + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667}; + float d = c + x2; + y = d; + } + + template <> + __host__ __device__ constexpr void operator()( + half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const + { + float a = x0 + x1; + float b = a + float{3}; + float c = (b > 0) * (b > float{6} ? 
float{6} : b) * a * float{0.166667}; + float d = c + x2; + y = d; + } +}; + +// C = A * B +// E = C + D0 + D1 +struct AddAdd +{ + template + __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const + { + // Only support floating so far + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + static_assert(is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + const C y = c + type_convert(d0) + type_convert(d1); + e = type_convert(y); + } +}; + +// C = A * B +// E = (C + D0) x D1 +struct AddMultiply +{ + template + __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ void operator()(half_t& e, + const half_t& c, + const half_t& d0, + const half_t& d1) const + { + const half_t y = (c + d0) * d1; + e = y; + } + template <> + __host__ __device__ void operator()(half_t& e, + const float& c, + const half_t& d0, + const half_t& d1) const + { + const half_t y = (type_convert(c) + d0) * d1; + e = y; + } + template <> + __host__ __device__ void operator()(float& e, + const float& c, + const half_t& d0, + const half_t& d1) const + { + const float y = (c + d0) * d1; + e = y; + } +}; + +// C = A * B +// E = C x D0 + D1 +struct MultiplyAdd +{ + template + __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ void operator()(half_t& e, + const half_t& c, + const half_t& d0, + const half_t& d1) const + { + const half_t y = (c * d0) + d1; + e = y; + } + template <> + __host__ __device__ void operator()(half_t& e, + const float& c, + const half_t& d0, + const half_t& d1) const + { + const half_t y = type_convert(c) * d0 + d1; + e = y; + } + template <> + __host__ __device__ void operator()(bhalf_t& e, + const float& c, + const bhalf_t& d0, + const bhalf_t& d1) const + { + const bhalf_t y = type_convert(c) * d0 + d1; + e = y; + } + template <> + __host__ __device__ void operator()(float& e, + const float& c, + const half_t& d0, + const half_t& d1) const + { + const float y = c * d0 + d1; + e = y; + } + template <> + __host__ __device__ void operator()(half_t& e, + const float& c, + const float& d0, + const float& d1) const + { + const float y = c * d0 + d1; + e = y; + } +}; + +struct MultiplyMultiply +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const float& c, const float& d0, const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::bhalf_t& e, const float& c, const float& d0, const float& d1) const + { + const float x0_f = c * d0 * d1; + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::half_t& e, const int& c, const ck::half_t& d0, const ck::half_t& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + 
ck::half_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const + { + const float x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } +}; + +struct MultiplyAddFastGelu +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()( + ck::bhalf_t& e, const float& c, const ck::bhalf_t& d0, const ck::bhalf_t& d1) const + { + const float x0_f = c * ck::type_convert(d0) + ck::type_convert(d1); + + float x1_f = 0; + + FastGelu{}.template operator()(x1_f, x0_f); + + e = ck::type_convert(x1_f); + } +}; + +// E = FastGelu(C + D0 + D1) +struct AddAddFastGelu +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()(float& e, + const float& c, + const float& d0, + const float& d1) const + { + const float x = c + d0 + d1; + + FastGelu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void operator()( + half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const + { + const half_t x = c + d0 + d1; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(e, x); + } + + template <> + __host__ __device__ constexpr void operator()( + half_t& e, const float& c, const half_t& d0, const half_t& d1) const + { + const float x0_f = c + d0 + d1; + + float x1_f = 0; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(x1_f, + x0_f); + + e = type_convert(x1_f); + } + + template <> + __host__ __device__ constexpr void operator()( + bhalf_t& e, const float& c, const bhalf_t& d0, const bhalf_t& d1) const + { + const float x0_f = c + type_convert(d0) + type_convert(d1); + + float x1_f = 0; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(x1_f, + x0_f); + + e = type_convert(x1_f); + } + + template <> + __host__ __device__ constexpr void operator()( + int8_t& e, const int32_t& c, const int8_t& d0, const int8_t& d1) const + { + const float x0_f = + type_convert(c) + type_convert(d0) + type_convert(d1); + + float x1_f = 0; + + ck::tensor_operation::element_wise::FastGelu{}.template operator()(x1_f, + x0_f); + + e = type_convert(x1_f); + } +}; + +// E = Relu(alpha1 * C + alpha2 * D0 + D1) +struct ScaleAddScaleAddRelu +{ + + ScaleAddScaleAddRelu(const float alpha1 = 1.f, const float alpha2 = 1.f) + : alpha1_(alpha1), alpha2_(alpha2) + { + } + + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()(float& e, + const float& c, + const float& d0, + const float& d1) const + { + const float x = c * alpha1_ + alpha2_ * d0 + d1; + e = x > 0 ? x : 0; + } + + template <> + __host__ __device__ constexpr void operator()( + half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const + { + const float x = type_convert(c) * alpha1_ + alpha2_ * type_convert(d0) + + type_convert(d1); + + float result = 0; + result = x > 0 ? 
x : 0; + + e = type_convert(result); + } + + template <> + __host__ __device__ constexpr void operator()( + bhalf_t& e, const bhalf_t& c, const bhalf_t& d0, const bhalf_t& d1) const + { + const float x = type_convert(c) * alpha1_ + alpha2_ * type_convert(d0) + + type_convert(d1); + + float result = 0; + result = x > 0 ? x : 0; + + e = type_convert(result); + } + + template <> + __host__ __device__ constexpr void operator()( + int8_t& e, const int8_t& c, const float& d0, const float& d1) const + { + const float x = type_convert(c) * alpha1_ + alpha2_ * d0 + d1; + + float result = 0; + result = x > 0 ? x : 0; + + e = type_convert(result); + } + + const float alpha1_; + const float alpha2_; +}; + +struct Normalize +{ + // FIXME: is double absolutely necessary? + Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {} + + template + __host__ __device__ constexpr void operator()(T1& y, + const T1& x, + const T2& mean, + const T2& mean_square, + const T3& gamma, + const T3& beta) const; + + template <> + __host__ __device__ constexpr void operator()(half_t& y, + const half_t& x, + const float& mean, + const float& mean_square, + const half_t& gamma, + const half_t& beta) const + { + using ck::math::sqrt; + + float variance = mean_square - (mean * mean); + + float tmp_x = type_convert(x); + float tmp_gamma = type_convert(gamma); + float tmp_beta = type_convert(beta); + + float tmp_y = + ((tmp_x - mean) / sqrt(variance + type_convert(epsilon_))) * tmp_gamma + + tmp_beta; + + y = type_convert(tmp_y); + }; + + template <> + __host__ __device__ constexpr void operator()(float& y, + const float& x, + const float& mean, + const float& mean_square, + const float& gamma, + const float& beta) const + { + using ck::math::sqrt; + + float variance = mean_square - (mean * mean); + y = ((x - mean) / sqrt(variance + type_convert(epsilon_))) * gamma + beta; + }; + + template <> + __host__ __device__ constexpr void operator()(double& y, + const double& x, + const double& mean, + const double& mean_square, + const double& gamma, + const double& beta) const + { + using ck::math::sqrt; + + double variance = mean_square - (mean * mean); + y = ((x - mean) / sqrt(variance + epsilon_)) * gamma + beta; + }; + + // FIXME: is double absolutely necessary? 
+ double epsilon_; +}; + +// used by BatchNorm inference +// y = gamma * (x-mean) / sqrt(epsilon+variance) + beta +// The data type of mean and variance is used as AccDataType +struct NormalizeInInfer +{ + NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {} + + template + __host__ __device__ constexpr void operator()(T1& y, + const T1& x, + const T2& mean, + const T2& variance, + const T3& gamma, + const T4& beta) const + { + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + using ck::type_convert; + using ck::math::sqrt; + + T2 tmp_x, tmp_y; + + tmp_x = type_convert(x); + + tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert(epsilon_))) * + type_convert(gamma) + + type_convert(beta); + y = type_convert(tmp_y); + }; + + double epsilon_; +}; + +template +struct UnaryTypeConvert; + +template <> +struct UnaryTypeConvert +{ + __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const + { + y = ck::type_convert(x); + } +}; + +template <> +struct UnaryTypeConvert +{ + __host__ __device__ void operator()(ck::bhalf_t& y, float& x) const + { + y = ck::type_convert(x); + } +}; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/element/quantization_operation.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/element/quantization_operation.hpp new file mode 100755 index 000000000000..fefa6c793f7e --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/element/quantization_operation.hpp @@ -0,0 +1,286 @@ +#pragma once + +#include "ck/utility/data_type.hpp" +// #include "ck/utility/get_id.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { + +// Y = Sy * Qy +// W = Sw * Qw +// X = Sx * Qx +// B = Sb * Qb = Sw * Sx * Qb +// Where X, W, Y are float32, Qx, Qw, Qy are int8 +// Sx, Sw, Sy are scale of x, w, y (float32), which is calculated from quantization range +// Qb is int32, scale of B is Sw * Sx for convenient + +// Y = W @ X, where @ is convolution or matrix multiplication +// Sy * Qy = Sw * Qw @ Sx * Qx +// Qy = [(Sw*Sx)/Sy] * Qw @ Qx + +// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc +// Activation(Sy * Qy) = Sy * Activation(Qy) +template +struct Activation_Mul_Clamp +{ + // Convolution + Activation (piecewise linear function) + // If an activation is piecewise linear function, then Activation(Sy * Qy) = Sy * Activation(Qy) + // Z = Activation(Y) = Activation(W @ X) + // Sz * Qz = Activation(Sy * Qy) + // Qz = Sy / Sz * Activation(Qy) = (Sw * Sx / Sz) * Activation(Qw @ Qx) + + // requantScale_ = Sw * Sx / Sz + Activation_Mul_Clamp(float requantScale, Activation activationOp) + : requantScale_(requantScale), activationOp_(activationOp) + { + } + + __host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const + { + float y_fp32 = ck::type_convert(x); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + __device__ constexpr void operator()(int32_t& y, const int32_t& x) const + { + // CAUSION - We might type_convert to int8 in threadwise copy + // eg. 
GridwiseGemmDlMultipleD_km_kn_mn + float y_fp32 = ck::type_convert(x); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + __host__ constexpr void operator()(float& y, const float& x) const + { + // CAUSION - We might float in & float out in reference code + activationOp_(y, x); + y = math::clamp(requantScale_ * y, -128.f, 127.f); + } + + float requantScale_; + Activation activationOp_; +}; + +// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc +// If an activation is not piecewise linear function +// then Activation(Sy * Qy) != Sy * Activation(Qy) +template +struct Mul_Activation_Mul_Clamp +{ + // Convolution + Activation (non piecewise linear function) + // Z = Activation(Y) = Activation(W @ X) + // Sz * Qz = Activation(Sy * Qy) + // Qz = S1 * Activation[Sacc * (Qw @ Qx)] + // Where S1 = 1 / Sz, Sacc = Sw * Sx + Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp) + : scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp) + { + } + + __host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const + { + float y_fp32 = ck::type_convert(x); + y_fp32 = scaleAcc_ * y_fp32; + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + float scale_z_inv_; + float scaleAcc_; + Activation activationOp_; +}; + +// Conv Perchannel quantization + Activation function which is piecewise linear function, such as +// relu, leaky relu ...etc +// Activation(Sy * Qy) = Sy * Activation(Qy) +template +struct Activation_Mul2_Clamp +{ + Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {} + + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x, const float& requantScale) const + { + float y_fp32 = ck::type_convert(x); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + __device__ constexpr void + operator()(int32_t& y, const int32_t& x, const float& requantScale) const + { + // CAUSION - We might type_convert to int8 in threadwise copy + // eg. 
GridwiseGemmDlMultipleD_km_kn_mn + float y_fp32 = ck::type_convert(x); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + Activation activationOp_; +}; + +// For Activation function which is piecewise linear function, such as relu, leaky relu ...etc +// Activation(Sy * Qy) = Sy * Activation(Qy) +template +struct Add_Activation_Mul_Clamp +{ + // Convolution + bias + // Let Bias = B = Sw * Sx * Qb + // Where Qb is int32 + // Y = W @ X + B + // Sy * Qy = Sw * Qw @ Sx * Qx + Sw * Sx * Qb + // Qy = [(Sw*Sx)/Sy] * (Qw @ Qx + Qb) + + // For activation, Z = Activaiton(Y) + // Sz * Qz = Activation(Sy * Qy) + // Qz = Sy / Sz * Activation(Qy) = [(Sw*Sx)/Sz] * Activation(Qw @ Qx + Qb) + Add_Activation_Mul_Clamp(float requantScale, Activation activationOp) + : requantScale_(requantScale), activationOp_(activationOp) + { + } + + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x, const int32_t& bias) const + { + float y_fp32 = ck::type_convert(x + bias); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + __host__ __device__ constexpr void + operator()(int32_t& y, const int32_t& x, const int32_t& bias) const + { + // CAUSION - We might type_convert to int8 in threadwise copy + // eg. GridwiseGemmDlMultipleD_km_kn_mn + float y_fp32 = ck::type_convert(x + bias); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + float requantScale_; + Activation activationOp_; +}; + +// Conv Perchannel quantization + Activation function which is piecewise linear function, such as +// relu, leaky relu ...etc +template +struct Add_Activation_Mul2_Clamp +{ + Add_Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {} + + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const + { + float y_fp32 = ck::type_convert(x + bias); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + __host__ __device__ constexpr void + operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& requantScale) const + { + // CAUSION - We might type_convert to int8 in threadwise copy + // eg. 
GridwiseGemmDlMultipleD_km_kn_mn + float y_fp32 = ck::type_convert(x + bias); + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(requantScale * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + Activation activationOp_; +}; + +// For Activation function which is non piecewise linear function, such as TanH, Sigmoid ...etc +// If an activation is not piecewise linear function +// then Activation(Sy * Qy) != Sy * Activation(Qy) +template +struct Add_Mul_Activation_Mul_Clamp +{ + // Convolution + Activation (non piecewise linear function) + // Z = Activation(Y) = Activation(W @ X + B) + // Sz * Qz = Activation(Sy * Qy) + // Qz = S1 * Activation[Sacc * (Qw @ Qx + Qb)] + // Where S1 = 1 / Sz, Sacc = Sw * Sx + Add_Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp) + : scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp) + { + } + + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x, const int32_t& bias) const + { + float y_fp32 = ck::type_convert(x + bias); + y_fp32 = scaleAcc_ * y_fp32; + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + __host__ __device__ constexpr void + operator()(int32_t& y, const int32_t& x, const int32_t& bias) const + { + // CAUSION - We might type_convert to int8 in threadwise copy + // eg. GridwiseGemmDlMultipleD_km_kn_mn + float y_fp32 = ck::type_convert(x + bias); + y_fp32 = scaleAcc_ * y_fp32; + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + float scale_z_inv_; + float scaleAcc_; + Activation activationOp_; +}; + +// Conv Perchannel quantization + Activation function which is non piecewise linear function, +// such as TanH, Sigmoid ...etc +// If an activation is not piecewise linear function +// then Activation(Sy *Qy) != Sy * Activation(Qy) +template +struct Add_Mul2_Activation_Mul_Clamp +{ + Add_Mul2_Activation_Mul_Clamp(float scale_z_inv, Activation activationOp) + : scale_z_inv_(scale_z_inv), activationOp_(activationOp) + { + } + + __host__ __device__ constexpr void + operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const + { + float y_fp32 = ck::type_convert(x + bias); + y_fp32 = scaleAcc * y_fp32; + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + __host__ __device__ constexpr void + operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const + { + // CAUSION - We might type_convert to int8 in threadwise copy + // eg. 
GridwiseGemmDlMultipleD_km_kn_mn + float y_fp32 = ck::type_convert(x + bias); + y_fp32 = scaleAcc * y_fp32; + activationOp_(y_fp32, y_fp32); + y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f); + y = ck::type_convert(y_fp32); + } + + float scale_z_inv_; + Activation activationOp_; +}; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp new file mode 100755 index 000000000000..8f829496da1d --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -0,0 +1,1811 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/utility/type_convert.hpp" +#include "ck/utility/amd_inline_asm.hpp" +#include + +namespace ck { + +// Fast int4x4 to half8_t data type conversion based on paper +// [Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production] +// (https://arxiv.org/abs/2211.10017) and implementation: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +// Convert lower part of packed int4 -> int4 to half +__device__ inline half4_t i4_to_half4(int q) +{ + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + + // Extract the two int4 at low bit and create two fp16 number. + int lo = amd_assembly_and_or_b32(q, LO, EX); + // Extract the two int4 at hight bit and create two fp16 number. + int hi = amd_assembly_and_or_b32(q, HI, EX); + + const int SUB = 0xE408E408; // half2 {-1032, -1032} + const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16} + const int ADD = 0xd480d480; // half2 {-72, -72} + + vector_type res; + + // for two fp16 from lowbit, subtract 1032 to get correct fp16 value + res.template AsType()(Number<0>{}) = + amd_assembly_pk_add_f16(bit_cast(lo), bit_cast(SUB)); + + // for two fp16 from highbit, divide 16 and subtract 72 to get correct fp16 value + res.template AsType()(Number<1>{}) = amd_assembly_pk_fma_f16( + bit_cast(hi), bit_cast(MUL), bit_cast(ADD)); + + return res.template AsType()[Number<0>{}]; +} + +__device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale) +{ + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + + // Extract the two int4 at low bit and create two fp16 number. + int lo = amd_assembly_and_or_b32(q, LO, EX); + // Extract the two int4 at hight bit and create two fp16 number. 
+ int hi = amd_assembly_and_or_b32(q, HI, EX); + + const int SUB = 0xE408E408; // half2 {-1032, -1032} + const int MUL = 0x2c002c00; // half2 {1 / 16, 1 / 16} + const int ADD = 0xd480d480; // half2 {-72, -72} + + vector_type res; + + res.template AsType()(Number<0>{}) = + amd_assembly_pk_add_f16(bit_cast(lo), bit_cast(SUB)); + + res.template AsType()(Number<1>{}) = amd_assembly_pk_fma_f16( + bit_cast(hi), bit_cast(MUL), bit_cast(ADD)); + + asm volatile("v_pk_mul_f16 %0, %1, %2" + : "=v"(res.template AsType()(Number<0>{})) + : "v"(res.template AsType()(Number<0>{})), "v"(scale)); + + asm volatile("v_pk_mul_f16 %0, %1, %2" + : "=v"(res.template AsType()(Number<1>{})) + : "v"(res.template AsType()(Number<1>{})), "v"(scale)); + + return res.template AsType()[Number<0>{}]; +} + +__device__ inline f8x4_t i4_to_f8x4(int q) +{ + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + + int lo = amd_assembly_and_b32(q, LO); + int hi = amd_assembly_and_b32(q, HI); + + float f32_0 = amd_assemble_cvt_f32_i4(lo); + float f32_1 = amd_assemble_cvt_f32_i4(lo >> 16); + float f32_2 = amd_assemble_cvt_f32_i4(hi); + float f32_3 = amd_assemble_cvt_f32_i4(hi >> 16); + + return amd_assembly_cvt_f8_to_f32(f32_0, f32_1, f32_2, f32_3); +} + +__device__ inline f8x8_t i4_to_fp8x8(int q) { return amd_assembly_i4_to_fp8x8(q); } + +__device__ inline bhalf4_t i4_to_bhalf4(int q) +{ + uint32_t i8s = (q & 0xf) | ((q & 0xf0) << 4) | ((q & 0xf00) << 8) | ((q & 0xf000) << 12); + + static constexpr uint32_t fp32_base = 0x4B000000; + + float fp32_intermediates[4]; + + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + fp32_intermediates_casted[0] = __byte_perm(i8s, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(i8s, fp32_base, 0x7651); + fp32_intermediates_casted[2] = __byte_perm(i8s, fp32_base, 0x7652); + fp32_intermediates_casted[3] = __byte_perm(i8s, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388616.f; + fp32_intermediates[1] -= 8388616.f; + fp32_intermediates[2] -= 8388616.f; + fp32_intermediates[3] -= 8388616.f; + + vector_type res; + res.template AsType()(Number<0>{}) = bit_cast( + __byte_perm(fp32_intermediates_casted[1], fp32_intermediates_casted[0], 0x7632)); + res.template AsType()(Number<1>{}) = bit_cast( + __byte_perm(fp32_intermediates_casted[3], fp32_intermediates_casted[2], 0x7632)); + + return res.template AsType()[Number<0>{}]; +} + +namespace tensor_operation { +namespace element_wise { + +struct PassThroughPack8 +{ + template + __host__ __device__ void operator()(Y& y, const X& x) const; + + __host__ __device__ constexpr void operator()(ck::half8_t& y, const ck::pk_i4x4_t& x) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + vector_type result; + + result.template AsType()(Number<0>{}) = i4_to_half4(bit_cast(x)); + result.template AsType()(Number<1>{}) = i4_to_half4(bit_cast(x) >> 8); + + y = result.template AsType()[Number<0>{}]; +#else + vector_type dst; + vector_type src{x}; + + dst.template AsType()(Number<0>{}) = + type_convert(src.template AsType()[Number<0>{}]); + dst.template AsType()(Number<1>{}) = + type_convert(src.template AsType()[Number<1>{}]); + dst.template AsType()(Number<2>{}) = + type_convert(src.template AsType()[Number<2>{}]); + dst.template AsType()(Number<3>{}) = + type_convert(src.template AsType()[Number<3>{}]); + + y = dst.template AsType()[Number<0>{}]; +#endif + } + + __host__ __device__ constexpr void operator()(ck::f8x8_t& y, const ck::pk_i4x4_t& x) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + y = i4_to_fp8x8(bit_cast(x)); + +#else + 
// Added pk_i4_t to f8x2_fnuz_t conversion + vector_type dst; + vector_type dst_tmp; + vector_type src{x}; + + // pk_i4_t to float2_t conversion + dst_tmp.template AsType()(Number<0>{}) = + type_convert(src.template AsType()[Number<0>{}]); + + dst_tmp.template AsType()(Number<1>{}) = + type_convert(src.template AsType()[Number<1>{}]); + + dst_tmp.template AsType()(Number<2>{}) = + type_convert(src.template AsType()[Number<2>{}]); + + dst_tmp.template AsType()(Number<3>{}) = + type_convert(src.template AsType()[Number<3>{}]); + + // float to f8_t conversion + dst.template AsType()(Number<0>{}) = + type_convert(dst_tmp.template AsType()[Number<0>{}]); + dst.template AsType()(Number<1>{}) = + type_convert(dst_tmp.template AsType()[Number<1>{}]); + + dst.template AsType()(Number<2>{}) = + type_convert(dst_tmp.template AsType()[Number<2>{}]); + dst.template AsType()(Number<3>{}) = + type_convert(dst_tmp.template AsType()[Number<3>{}]); + + dst.template AsType()(Number<4>{}) = + type_convert(dst_tmp.template AsType()[Number<4>{}]); + dst.template AsType()(Number<5>{}) = + type_convert(dst_tmp.template AsType()[Number<5>{}]); + + dst.template AsType()(Number<6>{}) = + type_convert(dst_tmp.template AsType()[Number<6>{}]); + dst.template AsType()(Number<7>{}) = + type_convert(dst_tmp.template AsType()[Number<7>{}]); + + y = dst.template AsType()[Number<0>{}]; +#endif + } + + __host__ __device__ constexpr void operator()(ck::bhalf8_t& y, const ck::pk_i4x4_t& x) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + vector_type result; + + result.template AsType()(Number<0>{}) = i4_to_bhalf4(bit_cast(x)); + result.template AsType()(Number<1>{}) = i4_to_bhalf4(bit_cast(x) >> 16); + + y = result.template AsType()[Number<0>{}]; +#else + vector_type dst; + vector_type src{x}; + + dst.template AsType()(Number<0>{}) = + type_convert(src.template AsType()[Number<0>{}]); + dst.template AsType()(Number<1>{}) = + type_convert(src.template AsType()[Number<1>{}]); + dst.template AsType()(Number<2>{}) = + type_convert(src.template AsType()[Number<2>{}]); + dst.template AsType()(Number<3>{}) = + type_convert(src.template AsType()[Number<3>{}]); + + y = dst.template AsType()[Number<0>{}]; +#endif + } + constexpr const static bool is_pack8_invocable = true; +}; + +struct DequantPack8 +{ + template + __host__ __device__ void operator()(Y& y, const X& x, const Z& z) const; + + __host__ __device__ constexpr void + operator()(ck::half8_t& y, const ck::pk_i4x4_t& x, const ck::half2_t& z) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + vector_type result; + + result.template AsType()(Number<0>{}) = i4_to_half4_scale(bit_cast(x), z); + result.template AsType()(Number<1>{}) = + i4_to_half4_scale(bit_cast(x) >> 8, z); + + y = result.template AsType()[Number<0>{}]; +#else + vector_type dst; + vector_type src{x}; + + dst.template AsType()(Number<0>{}) = + type_convert(src.template AsType()[Number<0>{}]); + dst.template AsType()(Number<1>{}) = + type_convert(src.template AsType()[Number<1>{}]); + dst.template AsType()(Number<2>{}) = + type_convert(src.template AsType()[Number<2>{}]); + dst.template AsType()(Number<3>{}) = + type_convert(src.template AsType()[Number<3>{}]); + + y = dst.template AsType()[Number<0>{}]; +#endif + } + + constexpr const static bool is_pack8_invocable = true; +}; + +struct PassThroughPack2 +{ + template + __host__ __device__ void operator()(Y& y, const X& x) const; + + __host__ __device__ constexpr void operator()(half2_t& y, const f8x2_t& x) const + { + auto t = type_convert(x); + y = type_convert(t); + } + + 
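The i4_to_half4 / i4_to_half4_scale / i4_to_bhalf4 helpers above avoid per-nibble integer-to-float conversion: each 4-bit field is OR-ed into the mantissa of a value with a fixed exponent (0x6400 is fp16 1024), and the resulting bias is removed with a single packed add or fma. A minimal host-side sketch of the arithmetic, assuming the nibbles carry signed int4 values stored with a +8 offset, which is what the -1032 and -72 magic constants imply:

// Standalone illustration of the magic-constant unpacking used by i4_to_half4.
// Plain-float arithmetic only, so it can be checked on the host; not CK code.
#include <cstdio>

int main()
{
    for (unsigned n = 0; n < 16; ++n)
    {
        // low-nibble path: OR into the mantissa of fp16(1024) gives 1024 + n,
        // then add -1032                          ->  n - 8
        float lo = (1024.0f + static_cast<float>(n)) - 1032.0f;
        // high-nibble path: the field sits 4 bits higher, so the OR gives
        // 1024 + 16*n; fma with 1/16 and -72      ->  n - 8
        float hi = (1024.0f + 16.0f * static_cast<float>(n)) * (1.0f / 16.0f) - 72.0f;
        std::printf("nibble %2u -> low-path %5.1f  high-path %5.1f\n", n, lo, hi);
    }
    return 0;
}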
__host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const + { +#if CK_USE_PK4_LAYOUT_SHUFFLE + uint8_t x_u8 = ck::bit_cast(x); + uint8_t x_l = (x_u8 & 0x0f) >> 0; + uint8_t x_h = (x_u8 & 0xf0) >> 4; + + auto l_f16 = ck::type_convert(x_l); + auto h_f16 = ck::type_convert(x_h); + + y = {l_f16, h_f16}; +#else + uint32_t t = ck::bit_cast(x); + y = ck::bit_cast(t); +#endif + } + + constexpr const static bool is_pack2_invocable = true; +}; + +struct PassThrough +{ + template + __host__ __device__ void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(pk_i4_t& y, const pk_i4_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(f4x2_pk_t& y, + const f4x2_pk_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(double& y, const double& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(float& y, const double& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(double& y, const float& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(float& y, const float& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(half_t& y, const half_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(half_t& y, const float& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(half_t& y, const int32_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(float& y, const int32_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(int32_t& y, const int32_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(float& y, const bhalf_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(bhalf_t& y, const half_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(float& y, const half_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(int8_t& y, const int8_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(half_t& y, const int8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(bhalf_t& y, const int8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(uint8_t& y, const uint8_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(int8_t& y, const int32_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(int32_t& y, const int8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(int8_t& y, const float& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(float& y, const int8_t& x) const + { + y = type_convert(x); + } + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + template <> + __host__ __device__ void operator()(int4_t& y, const int4_t& x) const + { + y = x; + } + template <> + __host__ __device__ void operator()(int4_t& y, const int& x) const + { + y 
= type_convert(x); + } +#endif + + template <> + __host__ __device__ void operator()(f8_t& y, const f8_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(float& y, const f8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(f8_t& y, const float& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(half_t& y, const f8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(f8_t& y, const half_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(bf8_t& y, const bf8_t& x) const + { + y = x; + } + + template <> + __host__ __device__ void operator()(float& y, const bf8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(bf8_t& y, const float& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(half_t& y, const bf8_t& x) const + { + y = type_convert(x); + } + + template <> + __host__ __device__ void operator()(bf8_t& y, const half_t& x) const + { + y = type_convert(x); + } +}; + +struct UnaryConvert +{ + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + y = type_convert(x); + } +}; + +struct ConvertBF16RTN +{ + // convert to bf16 using round to nearest (rtn) + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + // check Y datatype + static_assert(is_same::value, "Data type is not supported by this operation!"); + + // check X datatype + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = bf16_convert_rtn(x); + } +}; + +struct ConvertF8SR +{ + // convert to fp8 using stochastic rounding (SR) + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + // check Y datatype + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + // check X datatype + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = f8_convert_sr(x); + } +}; + +struct ConvertF8RNE +{ + // convert to fp8 using rounding to nearest even + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + // check Y datatype + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + // check X datatype + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = f8_convert_rne(x); + } +}; + +struct Scale +{ + __host__ __device__ Scale(float scale = 1.f) : scale_(scale) {} + + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + y = type_convert(type_convert(x) * scale_); + } + + template <> + __host__ __device__ void operator()(half_t& y, const half_t& x) const + { + y = type_convert(scale_) * x; + }; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + const float x_tmp = type_convert(x); + const float y_tmp = scale_ * x_tmp; + y = type_convert(y_tmp); + }; + + template <> + __host__ __device__ void operator()(float& y, const float& x) const + { + y = scale_ * x; + }; + + template <> + __host__ __device__ void operator()(double& y, const double& x) const + { + y = scale_ * x; + }; + + template <> + __host__ __device__ void operator()(int8_t& y, const int8_t& x) const + { + y = type_convert(scale_ * type_convert(x)); + }; + + float scale_; +}; + 
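The quantized epilogue functors in the preceding header (Activation_Mul_Clamp, Add_Activation_Mul_Clamp and their Mul2 per-channel variants) all reduce to the same scalar recipe: accumulate in int32, optionally add an int32 bias, apply the activation in fp32, multiply by the requantization scale (Sw*Sx)/Sz, clamp to the int8 range and convert back. A minimal scalar reference, assuming ReLU as the activation and round-to-nearest for the final float-to-int8 step (ck::type_convert may round differently):

// Scalar reference for the int8 requantization epilogue described above.
// requant_scale plays the role of (Sw * Sx) / Sz from the comments in that
// header; the rounding of the final conversion is an assumption.
#include <algorithm>
#include <cmath>
#include <cstdint>

inline int8_t requant_relu_clamp(int32_t acc, int32_t bias, float requant_scale)
{
    float y = static_cast<float>(acc + bias);                    // int32 accumulator + bias
    y = std::max(y, 0.0f);                                       // piecewise-linear activation (ReLU)
    y = std::min(std::max(requant_scale * y, -128.0f), 127.0f);  // clamp to int8 range
    return static_cast<int8_t>(std::lround(y));
}

The Mul2 variants differ only in taking the scale as a per-channel runtime argument instead of a constructor member, and the non-piecewise-linear variants (Mul_Activation_Mul_Clamp, Add_Mul_Activation_Mul_Clamp) apply Sacc = Sw*Sx before the activation and 1/Sz after it.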
+struct ScaleAndResetNaNToMinusInfinity +{ + __host__ __device__ ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {} + + template + __host__ __device__ void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(float& y, const float& x) const + { + y = math::isnan(x) ? -NumericLimits::Infinity() : scale_ * x; + }; + + float scale_; +}; + +struct UnaryDivide +{ + __host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider) {} + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = x / type_convert(divider_); + }; + + template <> + __host__ __device__ void operator()(half_t& y, const half_t& x) const + { + float x_ = type_convert(x); + float divider_f_ = type_convert(divider_); + + y = type_convert(x_ / divider_f_); + }; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + float x_ = type_convert(x); + float divider_f_ = type_convert(divider_); + + y = type_convert(x_ / divider_f_); + }; + + template <> + __host__ __device__ void operator()(f8_t& y, const f8_t& x) const + { + float x_ = type_convert(x); + float divider_f_ = type_convert(divider_); + + y = type_convert(x_ / divider_f_); + }; + + int32_t divider_ = 1; +}; + +struct UnarySquare +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + || is_same_v +#endif + , + "Data type is not supported by this operation!"); + y = x * x; + }; +}; + +struct UnaryAbs +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = math::abs(x); + }; + + template <> + __host__ __device__ void operator()(f8_t& y, const f8_t& x) const + { + y = ck::type_convert(ck::math::abs(ck::type_convert(x))); + }; + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + y = ck::type_convert(ck::math::abs(x)); + }; +}; + +struct UnarySqrt +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + y = math::sqrt(x); + }; +}; + +struct Clamp +{ + Clamp(float floor = 0.f, float ceil = NumericLimits::Max()) + : floor_(floor), ceil_(ceil){}; + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ constexpr void operator()(float& y, const float& x) const + { + const float& a = x; + y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + }; + + template <> + __host__ __device__ constexpr void operator()(double& y, const double& x) const + { + const double& a = x; + y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + }; + + template <> + __host__ __device__ constexpr void operator()(half_t& y, const half_t& x) const + { + const float a = type_convert(x); + const float b = a > floor_ ? (a < ceil_ ? 
a : ceil_) : floor_; + y = type_convert(b); + }; + + template <> + __host__ __device__ constexpr void operator()(half_t& y, const float& x) const + { + const float& a = x; + const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + y = type_convert(b); + }; + + template <> + __host__ __device__ constexpr void operator()(bhalf_t& y, const float& x) const + { + const float& a = x; + const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + y = type_convert(b); + }; + + template <> + __host__ __device__ constexpr void operator()(bhalf_t& y, + const bhalf_t& x) const + { + const float a = type_convert(x); + const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + y = type_convert(b); + }; + + template <> + __host__ __device__ constexpr void operator()(int& y, const int& x) const + { + const int8_t& a = x; + y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + }; + + template <> + __host__ __device__ constexpr void operator()(int8_t& y, const int8_t& x) const + { + const int8_t& a = x; + y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_; + }; + + const float floor_; + const float ceil_; +}; + +struct Relu +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + y = x > 0 ? x : 0; + } + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + float x_f32 = type_convert(x); + float y_f32 = x_f32 > 0 ? x_f32 : 0; + y = type_convert(y_f32); + } + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + float y_f32 = x > 0 ? 
x : 0; + y = type_convert(y_f32); + }; +}; + +// Fast GeLU +// https://paperswithcode.com/method/gelu +// y = 0.5*x*(1+tanh(sqrt(2/pi)*(x+0.044715*x^3))) +// host code use higher accuracy "exp" and "div" +// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function +struct FastGelu +{ + template + __host__ void operator()(Y& y, const X& x) const; + + template + __device__ void operator()(Y& y, const X& x) const; +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) + template <> + __host__ void operator()(float& y, const float& x) const + { + // const float u = -2.f * x * (0.035677f * x * x + 0.797885f); + const float c1 = -2.0 * 0.035677f; + const float c2 = -2.0 * 0.797885f; + const float u = x * (c1 * x * x + c2); + const float emu = exp(u); + y = x / (1.f + emu); + } +#endif + // device code, use lower precision "__ocml_exp_f32" and "rcp" + template <> + __device__ void operator()(float& y, const float& x) const + { + // const float u = 2.f * x * (0.035677f * x * x + 0.797885f); + const float c1 = -2.0 * 0.035677f; + const float c2 = -2.0 * 0.797885f; + const float u = x * (c1 * x * x + c2); + const float emu = __ocml_exp_f32(u); + + y = x * math::rcp(1.f + emu); + } + + template <> + __host__ void operator()(half_t& y, const half_t& x) const + { + float y_f; + + this->operator()(y_f, type_convert(x)); + + y = type_convert(y_f); + } + + template <> + __device__ void operator()(half_t& y, const half_t& x) const + { + float y_f; + + this->operator()(y_f, type_convert(x)); + + y = type_convert(y_f); + } + + template <> + __host__ void operator()(half_t& y, const float& x) const + { + float y_f; + + this->operator()(y_f, x); + + y = type_convert(y_f); + } + + template <> + __device__ void operator()(half_t& y, const float& x) const + { + float y_f; + + this->operator()(y_f, x); + + y = type_convert(y_f); + } + + template <> + __host__ void operator()(bhalf_t& y, const float& x) const + { + float y_f; + + this->operator()(y_f, x); + + y = type_convert(y_f); + } + + template <> + __device__ void operator()(bhalf_t& y, const float& x) const + { + float y_f; + + this->operator()(y_f, x); + + y = type_convert(y_f); + } + + template <> + __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + float y_f; + + this->operator()(y_f, type_convert(x)); + + y = type_convert(y_f); + } + + template <> + __host__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + float y_f; + + this->operator()(y_f, type_convert(x)); + + y = type_convert(y_f); + } +}; + +// https://paperswithcode.com/method/gelu +// y = 0.5*x*(1+erf(x/sqrt(2))) +struct Gelu +{ + template + __host__ __device__ void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(float& y, const float& x) const + { + y = 0.5f * x * (1.f + erf(float(0.70710678118f * x))); + } + + template <> + __host__ __device__ void operator()(half_t& y, const half_t& x) const + { + y = half_t(0.5) * x * (half_t(1) + half_t(erf(float(0.70710678118f * x)))); + } +}; + +struct Sigmoid +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + constexpr T one = type_convert(1); + y = one / (one + math::exp(-x)); + }; + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + constexpr float one = 1.f; + y = 
type_convert(one / (one + math::exp(-x))); + }; +}; + +struct Silu +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v, + "Data type is not supported by this operation!"); + constexpr T one = type_convert(1); + y = x * (one / (one + math::exp(-x))); + }; +}; + +struct TanH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = math::tanh(x); + }; + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + y = type_convert(math::tanh(x)); + }; +}; + +struct ACos +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = math::acos(x); + }; +}; + +struct Neg +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = math::neg(x); + }; +}; + +struct ATan +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = math::atan(x); + }; +}; + +struct Sin +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = math::sin(x); + }; +}; + +struct ASinH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = math::asinh(x); + }; +}; + +struct Cos +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = cos(x); + }; +}; + +struct ACosH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = math::acosh(x); + }; +}; + +struct Tan +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = math::tan(x); + }; +}; + +struct ATanH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = math::atanh(x); + }; +}; + +struct SinH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + 
static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = math::sinh(x); + }; +}; + +struct Ceil +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = math::ceil(x); + }; +}; + +struct Exp +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = math::exp(x); + }; +}; + +struct CosH +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = math::cosh(x); + }; +}; + +struct Floor +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = math::floor(x); + }; +}; + +struct Log +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = math::log(x); + }; +}; + +struct ASin +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = math::asin(x); + }; +}; + +struct Rcp +{ + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + + y = math::rcp(x); + }; +}; + +struct Swish +{ + Swish(float beta = 1.0f) : beta_(beta) {} + + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Data type is not supported by this operation!"); + + float bx = -beta_ * type_convert(x); + y = type_convert(x / (1.f + math::exp(bx))); + }; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + float bx = -beta_ * x; + y = type_convert(x / (1.f + math::exp(bx))); + }; + + const float beta_; +}; + +struct SoftRelu +{ + SoftRelu(float alpha = 1.f) : alpha_(alpha){}; + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + constexpr T one = type_convert(1); + y = math::log(one + math::exp(x * casted_alpha)) / casted_alpha; + } + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + 
constexpr float one = 1.f; + y = type_convert(math::log(one + math::exp(x * alpha_)) / alpha_); + }; + const float alpha_; +}; + +struct Power +{ + Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f) + : alpha_(alpha), beta_(beta), gamma_(gamma){}; + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + T casted_beta = type_convert(beta_); + T casted_gamma = type_convert(gamma_); + T shifted_scaled_x = casted_alpha + casted_beta * x; + y = math::pow(shifted_scaled_x, casted_gamma); + } + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + const float shifted_scaled_x = alpha_ + beta_ * x; + y = type_convert(math::pow(shifted_scaled_x, gamma_)); + }; + + const float alpha_; + const float beta_; + const float gamma_; +}; + +struct ClippedRelu +{ + ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + T casted_beta = type_convert(beta_); + y = math::min(casted_beta, math::max(casted_alpha, x)); + } + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + y = type_convert(math::min(beta_, math::max(alpha_, x))); + }; + + const float alpha_; + const float beta_; +}; + +struct LeakyRelu +{ + LeakyRelu(float alpha = 0.01f) : alpha_(alpha){}; + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + y = x >= 0 ? x : x * casted_alpha; + } + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + y = type_convert(x >= 0 ? x : x * alpha_); + }; + + const float alpha_; +}; + +struct Elu +{ + Elu(float alpha = 1.f) : alpha_(alpha){}; + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + y = x > 0 ? x : casted_alpha * math::expm1(x); + } + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + y = type_convert(x > 0 ? 
x : alpha_ * math::expm1(x)); + }; + + const float alpha_; +}; + +struct Logistic +{ + Logistic(float alpha = 1.f) : alpha_(alpha){}; + + template + __host__ __device__ void operator()(T& y, const T& x) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Data type is not supported by this operation!"); + T casted_alpha = type_convert(alpha_); + constexpr T one = type_convert(1); + y = casted_alpha / (one + ck::math::exp(-x) * casted_alpha); + } + + template + __host__ __device__ constexpr void operator()(Y& y, const X& x) const; + + template <> + __host__ __device__ void operator()(bhalf_t& y, const float& x) const + { + constexpr float one = 1.f; + y = type_convert(alpha_ / (one + ck::math::exp(-x) * alpha_)); + }; + const float alpha_; +}; + +struct ConvInvscale +{ + __host__ __device__ ConvInvscale(float scale_in = 1.f, + float scale_wei = 1.f, + float scale_out = 1.f) + : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) + { + } + + template + __host__ __device__ void operator()(E& e, const C& c) const; + + template <> + __host__ __device__ void operator()(f8_t& e, const float& c) const + { + e = type_convert(c / scale_in_ / scale_wei_ / scale_out_); + }; + + float scale_in_; + float scale_wei_; + float scale_out_; +}; + +struct ConvScale +{ + __host__ __device__ ConvScale(float scale_in = 1.f, + float scale_wei = 1.f, + float scale_out = 1.f) + : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) + { + } + + template + __host__ __device__ void operator()(E& e, const C& c) const; + + template <> + __host__ __device__ void operator()(f8_t& e, const float& c) const + { + e = type_convert(c * scale_in_ * scale_wei_ * scale_out_); + }; + + float scale_in_; + float scale_wei_; + float scale_out_; +}; + +struct ConvScaleRelu +{ + __host__ __device__ ConvScaleRelu(float scale_in = 1.f, + float scale_wei = 1.f, + float scale_out = 1.f) + : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out) + { + } + + template + __host__ __device__ void operator()(E& e, const C& c) const; + + template <> + __host__ __device__ void operator()(f8_t& e, const float& c) const + { + float x; + Relu{}.template operator()(x, c * scale_in_ * scale_wei_); + e = type_convert(x * scale_out_); + }; + + float scale_in_; + float scale_wei_; + float scale_out_; +}; + +// support fastconvert of int8 to fp16 + +template +struct FastNumericArrayConverter +{ +}; + +template <> +struct FastNumericArrayConverter +{ + using InputArray = vector_type; + using OutputArray = vector_type; + + __device__ static OutputArray convert(InputArray const& Input) + { + OutputArray Output; + + uint32_t* half_2 = reinterpret_cast(&Output); + uint32_t const uint8_4 = reinterpret_cast(Input); + + static constexpr uint32_t byte_selector_01 = 0x05010500; + static constexpr uint32_t byte_selector_23 = 0x05030502; + static constexpr uint32_t fp16_adder = 0x64646464; + half_2[0] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_01); + half_2[1] = __builtin_amdgcn_perm(fp16_adder, uint8_4, byte_selector_23); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]" + : "=v"(half_2[0]) + : "v"(half_2[0]), "s"(I8s_TO_F16s_MAGIC_NUM)); + asm volatile("v_pk_add_f16 %0, %1, %2 neg_lo:[0,1] neg_hi:[0,1]" + : "=v"(half_2[1]) + : "v"(half_2[1]), "s"(I8s_TO_F16s_MAGIC_NUM)); + + return Output; + } + + __device__ OutputArray operator()(InputArray const& Input) { return 
convert(Input); } +}; + +template +struct FastNumericArrayConverter +{ + static constexpr int VEC_WIDTH = 4; + static_assert(!(N % VEC_WIDTH), "N must be multiple of 4."); + + using InputArray = vector_type; + using OutputArray = vector_type; + + __device__ static OutputArray convert(InputArray const& Input) + { + FastNumericArrayConverter converter; + + OutputArray Output; + + using Vec_InputArray = vector_type; + using Vec_OutputArray = vector_type; + + Vec_OutputArray* half_4_ptr = reinterpret_cast(&Output); + Vec_InputArray const* uint8_4_ptr = reinterpret_cast(&Input); + + static_for<0, N / VEC_WIDTH, 1>{}( + [&](auto i) { half_4_ptr[i] = converter(uint8_4_ptr[i]); }); + + return Output; + } + + __device__ OutputArray operator()(InputArray const& Input) { return convert(Input); } +}; + +struct DynamicUnaryOp +{ + __host__ __device__ DynamicUnaryOp() = delete; + + __host__ __device__ DynamicUnaryOp(const Swish& swish) + : unary_op_type_(UnaryOpType::Swish), swish_{swish.beta_} + { + } + + __host__ __device__ DynamicUnaryOp(const Swish&& swish) + : unary_op_type_(UnaryOpType::Swish), swish_{swish.beta_} + { + } + + __host__ __device__ DynamicUnaryOp(const Sigmoid&) : unary_op_type_(UnaryOpType::Sigmoid) {} + + __host__ __device__ DynamicUnaryOp(const Sigmoid&&) : unary_op_type_(UnaryOpType::Sigmoid) {} + + __host__ __device__ DynamicUnaryOp(const PassThrough&) + : unary_op_type_(UnaryOpType::PassThrough) + { + } + + __host__ __device__ DynamicUnaryOp(const PassThrough&&) + : unary_op_type_(UnaryOpType::PassThrough) + { + } + + __host__ __device__ DynamicUnaryOp(const Logistic& logistic) + : unary_op_type_(UnaryOpType::Logistic), logistic_{logistic.alpha_} + { + } + + __host__ __device__ DynamicUnaryOp(const Logistic&& logistic) + : unary_op_type_(UnaryOpType::Logistic), logistic_{logistic.alpha_} + { + } + + __host__ __device__ DynamicUnaryOp(const TanH&) : unary_op_type_(UnaryOpType::TanH) {} + + __host__ __device__ DynamicUnaryOp(const TanH&&) : unary_op_type_(UnaryOpType::TanH) {} + + __host__ __device__ DynamicUnaryOp(const Relu&) : unary_op_type_(UnaryOpType::Relu) {} + + __host__ __device__ DynamicUnaryOp(const Relu&&) : unary_op_type_(UnaryOpType::Relu) {} + + __host__ __device__ DynamicUnaryOp(const SoftRelu& softrelu) + : unary_op_type_(UnaryOpType::SoftRelu), soft_relu_{softrelu.alpha_} + { + } + + __host__ __device__ DynamicUnaryOp(const SoftRelu&& softrelu) + : unary_op_type_(UnaryOpType::SoftRelu), soft_relu_{softrelu.alpha_} + { + } + + __host__ __device__ DynamicUnaryOp(const UnaryAbs&) : unary_op_type_(UnaryOpType::UnaryAbs) {} + + __host__ __device__ DynamicUnaryOp(const UnaryAbs&&) : unary_op_type_(UnaryOpType::UnaryAbs) {} + + __host__ __device__ DynamicUnaryOp(const Power& pow) + : unary_op_type_(UnaryOpType::Power), power_(pow.alpha_, pow.beta_, pow.gamma_) + { + } + + __host__ __device__ DynamicUnaryOp(const Power&& pow) + : unary_op_type_(UnaryOpType::Power), power_(pow.alpha_, pow.beta_, pow.gamma_) + { + } + + __host__ __device__ DynamicUnaryOp(const ClippedRelu& clippedrelu) + : unary_op_type_(UnaryOpType::ClippedRelu), + clipped_relu_{clippedrelu.alpha_, clippedrelu.beta_} + { + } + + __host__ __device__ DynamicUnaryOp(const ClippedRelu&& clippedrelu) + : unary_op_type_(UnaryOpType::ClippedRelu), + clipped_relu_{clippedrelu.alpha_, clippedrelu.beta_} + { + } + + __host__ __device__ DynamicUnaryOp(const LeakyRelu& leakyrelu) + : unary_op_type_(UnaryOpType::LeakyRelu), leaky_relu_{leakyrelu.alpha_} + { + } + + __host__ __device__ DynamicUnaryOp(const LeakyRelu&& 
leakyrelu) + : unary_op_type_(UnaryOpType::LeakyRelu), leaky_relu_{leakyrelu.alpha_} + { + } + + __host__ __device__ DynamicUnaryOp(const Elu& elu) + : unary_op_type_(UnaryOpType::Elu), elu_{elu.alpha_} + { + } + + __host__ __device__ DynamicUnaryOp(const Elu&& elu) + : unary_op_type_(UnaryOpType::Elu), elu_{elu.alpha_} + { + } + + __host__ __device__ DynamicUnaryOp(const DynamicUnaryOp& dynamic_op) = default; + + __host__ __device__ ~DynamicUnaryOp() {} + + template + __host__ __device__ void operator()(Y& y, const X& x) const + { + switch(unary_op_type_) + { + case(UnaryOpType::Swish): swish_(y, x); break; + case(UnaryOpType::Sigmoid): sigmoid_(y, x); break; + case(UnaryOpType::PassThrough): pass_through_(y, x); break; + case(UnaryOpType::Logistic): logistic_(y, x); break; + case(UnaryOpType::TanH): tanh_(y, x); break; + case(UnaryOpType::Relu): relu_(y, x); break; + case(UnaryOpType::SoftRelu): soft_relu_(y, x); break; + case(UnaryOpType::UnaryAbs): unary_abs_(y, x); break; + case(UnaryOpType::Power): power_(y, x); break; + case(UnaryOpType::ClippedRelu): clipped_relu_(y, x); break; + case(UnaryOpType::LeakyRelu): leaky_relu_(y, x); break; + case(UnaryOpType::Elu): elu_(y, x); break; + default: break; + } + } + + template <> + __host__ __device__ void operator()(bhalf_t& y, const bhalf_t& x) const + { + float y_float; + float x_float = type_convert(x); + this->operator()(y_float, x_float); + y = type_convert(y_float); + } + + private: + enum class UnaryOpType + { + Swish, + Sigmoid, + PassThrough, + Logistic, + TanH, + Relu, + SoftRelu, + UnaryAbs, + Power, + ClippedRelu, + LeakyRelu, + Elu + }; + + public: + UnaryOpType unary_op_type_; + + Swish swish_; + Sigmoid sigmoid_; + PassThrough pass_through_; + Logistic logistic_; + TanH tanh_; + Relu relu_; + SoftRelu soft_relu_; + UnaryAbs unary_abs_; + Power power_; + ClippedRelu clipped_relu_; + LeakyRelu leaky_relu_; + Elu elu_; +}; + +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp new file mode 100755 index 000000000000..7c9febf4de91 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp @@ -0,0 +1,704 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
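// Grid-wise multiblock batch-norm forward. Several workgroups cooperate on the
// reduction of each M slice across the K dimension:
//   1. every workgroup computes a partial Welford mean/variance over its K tiles;
//   2. the partials (mean, variance, count) are written to workspace memory;
//   3. the workgroups of a block group synchronize through the gms_* control
//      words, then re-read all partials and merge them with a second Welford pass;
//   4. each workgroup normalizes its tiles as
//      y = (x - mean) * scale / sqrt(var + epsilon) + bias,
//      folded into a single multiplier and fused bias in the code;
//   5. optionally the running mean/variance are updated with the exponential
//      moving average controlled by averageFactor;
//   6. optionally the batch mean and 1 / sqrt(var + epsilon) are saved for reuse.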
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/utility/workgroup_synchronization.hpp" + +namespace ck { + +template +__global__ void kernel_multiblock_batchnorm_forward( + const XYGridDesc_M_K x_grid_desc_m_k, + const XYGridDesc_M_K y_grid_desc_m_k, + const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g, + const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, + const ScaleBiasGridDesc_M scale_grid_desc_m, + const ScaleBiasGridDesc_M bias_grid_desc_m, + const MeanVarGridDesc_M mean_var_grid_desc_m, + const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x, + MeanVarDataType* const __restrict__ p_welford_mean, + MeanVarDataType* const __restrict__ p_welford_variance, + int32_t* const __restrict__ p_welford_count, + int32_t* const __restrict__ p_control, + const ScaleDataType* const __restrict__ p_scale, + const BiasDataType* const __restrict__ p_bias, + const YElementwiseOp y_elementwise_op, + YDataType* const __restrict__ p_y, + bool updateMovingAverage, + AccDataType averageFactor, + MeanVarDataType* const __restrict__ resultRunningMean, + MeanVarDataType* const __restrict__ resultRunningVariance, + bool saveMeanInvVariance, + MeanVarDataType* const __restrict__ resultSaveMean, + MeanVarDataType* const __restrict__ resultSaveInvVariance) +{ + GridwiseMultiblockBatchNormForward_::Run(x_grid_desc_m_k, + y_grid_desc_m_k, + mean_var_count_grid_desc_m_g, + mean_var_count_grid_desc_m_k, + scale_grid_desc_m, + bias_grid_desc_m, + mean_var_grid_desc_m, + get_reduce_count_per_thread, + num_k_block_tile_iteration, + epsilon, + p_x, + p_welford_mean, + p_welford_variance, + p_welford_count, + p_control, + p_scale, + p_bias, + y_elementwise_op, + p_y, + updateMovingAverage, + averageFactor, + resultRunningMean, + resultRunningVariance, + saveMeanInvVariance, + resultSaveMean, + resultSaveInvVariance); +}; + +template +struct GridwiseMultiblockBatchNormForward +{ + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadReduceSrcDesc_M_1 = 
decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{}))); + + using ThreadwiseWelford1 = + ThreadwiseWelford; + + using ThreadwiseWelford2 = + ThreadwiseWelfordMerge; + + using BlockwiseWelford1 = BlockwiseWelford; + + using BlockwiseWelford2 = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k, + const XYGridDesc_M_K& y_grid_desc_m_k, + const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g, + const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k, + const ScaleBiasGridDesc_M& scale_grid_desc_m, + const ScaleBiasGridDesc_M& bias_grid_desc_m, + const MeanVarGridDesc_M& mean_var_grid_desc_m, + const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x, + MeanVarDataType* const __restrict__ p_welford_mean, + MeanVarDataType* const __restrict__ p_welford_variance, + int32_t* const __restrict__ p_welford_count, + int32_t* const __restrict__ p_control, + const ScaleDataType* const __restrict__ p_scale, + const BiasDataType* const __restrict__ p_bias, + const YElementwiseOp y_elementwise_op, + YDataType* const __restrict__ p_y, + bool updateMovingAverage, + AccDataType averageFactor, + MeanVarDataType* const __restrict__ resultRunningMean, + MeanVarDataType* const __restrict__ resultRunningVariance, + bool saveMeanInvVariance, + MeanVarDataType* const __restrict__ resultSaveMean, + MeanVarDataType* const __restrict__ resultSaveInvVariance) + { + using ck::math::sqrt; + + const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1); + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / blkgroup_size; + const index_t block_local_id = block_global_id % blkgroup_size; + + if(block_local_id == 0) + gms_init(BlockSize / WarpSize * blkgroup_size, &p_control[blkgroup_id * 2]); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + using ThreadBufferLengths_M_1 = Sequence; + + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{})); + + StaticBuffer + x_thread_buf; + + StaticBuffer mean_thread_buf; + StaticBuffer var_thread_buf; + StaticBuffer count_thread_buf; + + StaticBuffer + tmp_mean_thread_buf; + StaticBuffer + tmp_var_thread_buf; + StaticBuffer tmp_count_thread_buf; + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * 
reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + constexpr auto xy_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x, x_grid_desc_m_k.GetElementSpaceSize()); + + // Step 1: each workgroup does local welford reduction + + auto threadwise_welford_1 = ThreadwiseWelford1(); + threadwise_welford_1.max_count_ = + get_reduce_count_per_thread(block_local_id, thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k); + threadwise_welford_1.Run(x_thread_buf, mean_thread_buf, var_thread_buf); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + count_thread_buf(I) = threadwise_welford_1.cur_count_; + BlockwiseWelford1::Run(mean_thread_buf(I), var_thread_buf(I), count_thread_buf(I)); + }); + + // Step 2: each workgroup writes its local welford result to workspace memory + + auto mean_global_val_buf = + make_dynamic_buffer( + p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); + + auto var_global_val_buf = + make_dynamic_buffer( + p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); + + auto count_global_val_buf = + make_dynamic_buffer( + p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); + + auto threadwise_mean_var_store_m_g = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_count_grid_desc_m_g, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_local_id), + PassThroughOp{}); + + auto threadwise_count_store_m_g = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_count_grid_desc_m_g, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_local_id), + PassThroughOp{}); + + if(thread_k_cluster_id == 0) + { + threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + mean_thread_buf, + mean_var_count_grid_desc_m_g, + mean_global_val_buf); + + threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + var_thread_buf, + mean_var_count_grid_desc_m_g, + var_global_val_buf); + + threadwise_count_store_m_g.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + count_thread_buf, + mean_var_count_grid_desc_m_g, + count_global_val_buf); + }; + + gms_barrier(&p_control[blkgroup_id * 2]); + + if(block_local_id == 0) + gms_reset(&p_control[blkgroup_id * 2]); + + // Step 3: each workgroup reads welford results from workspace memory and does final welford + // reduction + + auto threadwise_mean_var_load_m_k = + ThreadwiseTensorSliceTransfer_v2, + 0, + 1, + 1, + true>( + mean_var_count_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * 1)); + + auto threadwise_count_load_m_k = + ThreadwiseTensorSliceTransfer_v2, + 0, + 1, + 1, + true>( + mean_var_count_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * 1)); + + 
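// The loop below folds the per-workgroup partials read from workspace memory
// into one mean/variance per M element. ThreadwiseWelfordMerge is expected to
// apply the standard parallel Welford combination (Chan et al.): for partials
// (m_a, v_a, n_a) and (m_b, v_b, n_b), with n = n_a + n_b and d = m_b - m_a,
//   mean = m_a + d * n_b / n
//   var  = (v_a * n_a + v_b * n_b + d * d * n_a * n_b / n) / n
// which makes the result independent of how K was split across workgroups.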
static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + count_thread_buf(I) = 0; + }); + + constexpr auto mean_var_count_read_fwd_step_m_k = make_multi_index(0, KThreadClusterSize); + + int32_t reducedSize = 0; + while(reducedSize < blkgroup_size) + { + threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, + mean_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + tmp_mean_thread_buf); + + threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, + var_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + tmp_var_thread_buf); + + threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k, + count_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + tmp_count_thread_buf); + + ThreadwiseWelford2::Run(tmp_mean_thread_buf, + tmp_var_thread_buf, + tmp_count_thread_buf, + mean_thread_buf, + var_thread_buf, + count_thread_buf); + + reducedSize += KThreadClusterSize; + + threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, + mean_var_count_read_fwd_step_m_k); + threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, + mean_var_count_read_fwd_step_m_k); + }; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseWelford2::Run(mean_thread_buf(I), var_thread_buf(I), count_thread_buf(I)); + }); + + // Step 4: do normalization using the mean/variance + + StaticBuffer scale_thread_buf; + + StaticBuffer bias_thread_buf; + + StaticBuffer + y_thread_buf; + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index( + blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize), + y_elementwise_op); + + auto threadwise_scale_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + ScaleSrcVectorSize, + 1, + true>( + scale_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2, + 0, + BiasSrcVectorSize, + 1, + true>( + bias_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + const auto scale_global_val_buf = make_dynamic_buffer( + p_scale, scale_grid_desc_m.GetElementSpaceSize()); + + const auto bias_global_val_buf = make_dynamic_buffer( + p_bias, bias_grid_desc_m.GetElementSpaceSize()); + + auto y_global_val_buf = make_dynamic_buffer( + p_y, y_grid_desc_m_k.GetElementSpaceSize()); + + threadwise_scale_load.Run(scale_grid_desc_m, + scale_global_val_buf, + thread_buffer_desc_m, + make_tuple(I0), + scale_thread_buf); + + threadwise_bias_load.Run(bias_grid_desc_m, + bias_global_val_buf, + thread_buffer_desc_m, + make_tuple(I0), + bias_thread_buf); + + threadwise_x_load.SetSrcSliceOrigin( + x_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + AccDataType multiplier = + scale_thread_buf[Number{}] / sqrt(var_thread_buf[iM] + epsilon); + + AccDataType fused_mean_bias = + 
bias_thread_buf[Number{}] - mean_thread_buf[iM] * multiplier; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + // normalize + y_thread_buf(Number{}) = + x_thread_buf[Number{}] * multiplier + fused_mean_bias; + }); + }); + + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf, + y_grid_desc_m_k, + y_global_val_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, xy_copy_fwd_step_m_k); + } + + // Step 5: update the moving average of mean and variance (optional) + + if(updateMovingAverage && block_local_id == 0 && thread_k_cluster_id == 0) + { + StaticBuffer + running_mean_thread_buf; + StaticBuffer + running_var_thread_buf; + + auto running_mean_global_buf = make_dynamic_buffer( + resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto running_var_global_buf = make_dynamic_buffer( + resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto threadwise_mean_var_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + MeanVarSrcDstVectorSize, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_mean_var_load.Run(mean_var_grid_desc_m, + running_mean_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + running_mean_thread_buf); + + threadwise_mean_var_load.Run(mean_var_grid_desc_m, + running_var_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + running_var_thread_buf); + + AccDataType oneMinusAverageFactor = type_convert(1.0) - averageFactor; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor + + mean_thread_buf[I] * averageFactor; + running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor + + var_thread_buf[I] * averageFactor; + }); + + auto threadwise_mean_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + MeanVarSrcDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_mean_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + running_mean_thread_buf, + mean_var_grid_desc_m, + running_mean_global_buf); + + threadwise_mean_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + running_var_thread_buf, + mean_var_grid_desc_m, + running_var_global_buf); + }; + + // Step 6: save mean and inv-variance (optional) + + if(saveMeanInvVariance && block_local_id == 0 && thread_k_cluster_id == 0) + { + auto result_mean_global_buf = make_dynamic_buffer( + resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto result_inv_var_global_buf = make_dynamic_buffer( + resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize()); + + // calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + var_thread_buf(I) = + type_convert(1.0f) / sqrt(epsilon + var_thread_buf[I]); + }); + + auto threadwise_mean_inv_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + MeanVarSrcDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + 
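// --- Illustrative sketch (editor's addition, not part of the patch): the normalize loop
// above folds batchnorm's affine transform into one multiply-add per element,
//   y = gamma * (x - mean) / sqrt(var + eps) + beta
//     = x * multiplier + fused_mean_bias,
// with multiplier = gamma / sqrt(var + eps) and fused_mean_bias = beta - mean * multiplier.
// Scalar reference, assuming <cmath> for sqrtf; variable names are illustrative.
inline float batchnorm_forward_reference(float x, float mean, float var, float eps,
                                         float gamma, float beta)
{
    const float multiplier      = gamma / sqrtf(var + eps);
    const float fused_mean_bias = beta - mean * multiplier;
    return x * multiplier + fused_mean_bias; // == gamma * (x - mean) / sqrt(var + eps) + beta
}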
threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + mean_var_grid_desc_m, + result_mean_global_buf); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + var_thread_buf, + mean_var_grid_desc_m, + result_inv_var_global_buf); + }; + } +}; // namespace ck + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp new file mode 100755 index 000000000000..4e182ec29ddd --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp @@ -0,0 +1,498 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_reduce_second_half_batchnorm_backward_final( + const XYGridDesc_M_K x_grid_desc_m_k, + const XYGridDesc_M_K dy_grid_desc_m_k, + const XYGridDesc_M_K dx_grid_desc_m_k, + const DscaleDbiasGridDesc_M_K dscale_dbias_grid_desc_m_k, + const MeanVarGridDesc_M mean_var_grid_desc_m, + const ScaleBiasGridDesc_M scale_grid_desc_m, + const ScaleBiasGridDesc_M bias_grid_desc_m, + index_t blkgroup_size, + long_index_t reduce_size, + index_t num_xy_k_block_tile_iteration, + index_t num_dscale_dbias_k_block_tile_iteration, + const DscaleDbiasDataType* const __restrict__ p_reduce_dscale, + const DscaleDbiasDataType* const __restrict__ p_reduce_dbias, + const MeanVarDataType* const __restrict__ p_mean, + const MeanVarDataType* const __restrict__ p_inv_var, + const XDataType* const __restrict__ p_x, + const DyDataType* const __restrict__ p_dy, + const ScaleDataType* const __restrict__ p_scale, + const DyElementwiseOp dy_elementwise_op, + DxDataType* const __restrict__ p_dx, + DscaleDbiasDataType* const __restrict__ p_dscale, + DscaleDbiasDataType* const __restrict__ p_dbias) +{ + GridwiseReduceSecondHalfBatchNormBackwardFinal_::Run(x_grid_desc_m_k, + dy_grid_desc_m_k, + dx_grid_desc_m_k, + dscale_dbias_grid_desc_m_k, + mean_var_grid_desc_m, + scale_grid_desc_m, + bias_grid_desc_m, + blkgroup_size, + reduce_size, + num_xy_k_block_tile_iteration, + num_dscale_dbias_k_block_tile_iteration, + p_reduce_dscale, + p_reduce_dbias, + p_mean, + p_inv_var, + p_x, + p_dy, + p_scale, + dy_elementwise_op, + p_dx, + p_dscale, + p_dbias); +}; + +template +struct GridwiseReduceSecondHalfBatchNormBackwardFinal +{ + static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 && + MThreadSliceSize % DySrcVectorSize == 0 && + MThreadSliceSize % DxDstVectorSize == 0) || + (XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 && + KThreadSliceSize % DySrcVectorSize == 0 && + KThreadSliceSize % DxDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XDyDxVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + 
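// --- Illustrative sketch (editor's addition, not part of the patch): a scalar reference of
// the per-element dx that Run() below computes (see its Step 2 comment), with
// N = reduce_size, dbias = sum(dy) and dscale = sum(dy * (x - mean) * inv_std) already
// reduced over the reduction dimension. Names are illustrative.
inline float batchnorm_backward_dx_reference(float x, float dy, float mean, float inv_std,
                                             float scale, float dbias, float dscale, float N)
{
    const float norm_x = (x - mean) * inv_std;
    return (1.0f / N) * inv_std * scale * (N * dy - dbias - dscale * norm_x);
}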
+ using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using BlockwiseReduce = PartitionedBlockwiseReduction; + + using ThreadwiseReduce = ThreadwiseReduction; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + // clang-format off + // Two of the steps of Multiblock BatchNorm Backward + // Step 1: Second half of Reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance) + // Step 2: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance)) elementwise-ly + // clang-format on + __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k, + const XYGridDesc_M_K& dy_grid_desc_m_k, + const XYGridDesc_M_K& dx_grid_desc_m_k, + const DscaleDbiasGridDesc_M_K& dscale_dbias_grid_desc_m_k, + const MeanVarGridDesc_M& mean_var_grid_desc_m, + const ScaleBiasGridDesc_M& scale_grid_desc_m, + const ScaleBiasGridDesc_M& dscale_dbias_grid_desc_m, + index_t blkgroup_size, + long_index_t reduce_size, + index_t num_xy_k_block_tile_iteration, + index_t num_dscale_dbias_k_block_tile_iteration, + const DscaleDbiasDataType* const __restrict__ p_reduce_dscale, + const DscaleDbiasDataType* const __restrict__ p_reduce_dbias, + const MeanVarDataType* const __restrict__ p_mean, + const MeanVarDataType* const __restrict__ p_inv_var, + const XDataType* const __restrict__ p_x, + const DyDataType* const __restrict__ p_dy, + const ScaleDataType* const __restrict__ p_scale, + const DyElementwiseOp dy_elementwise_op, + DxDataType* const __restrict__ p_dx, + DscaleDbiasDataType* const __restrict__ p_dscale, + DscaleDbiasDataType* const __restrict__ p_dbias) + { + __shared__ AccDataType p_reduce_work_buffer[BlockSize]; + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + StaticBuffer + reduce_dscale_thread_buf; + StaticBuffer + reduce_dbias_thread_buf; + + StaticBuffer dscale_thread_buf; + StaticBuffer dbias_thread_buf; + + StaticBuffer + x_thread_buf; + StaticBuffer + dy_thread_buf; + StaticBuffer + dx_thread_buf; + + StaticBuffer mean_thread_buf; + StaticBuffer + inv_var_thread_buf; + StaticBuffer scale_thread_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / blkgroup_size; + const index_t block_local_id = block_global_id % blkgroup_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + using ThreadBufferLengths_M_1 = Sequence; + constexpr auto thread_buffer_desc_m_k = 
make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{})); + + // clang-format off + // Step 1: do final reduction of dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance) + // clang-format on + + auto threadwise_dscale_dbias_load_m_k = + ThreadwiseTensorSliceTransfer_v2, + 1, + 1, + 1, + true>( + dscale_dbias_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * 1)); + + auto threadwise_dscale_dbias_store_m = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + DscaleDbiasDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + dscale_dbias_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + const auto reduce_dscale_global_buf = make_dynamic_buffer( + p_reduce_dscale, dscale_dbias_grid_desc_m_k.GetElementSpaceSize()); + + const auto reduce_dbias_global_buf = make_dynamic_buffer( + p_reduce_dbias, dscale_dbias_grid_desc_m_k.GetElementSpaceSize()); + + auto dscale_global_buf = make_dynamic_buffer( + p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize()); + + auto dbias_global_buf = make_dynamic_buffer( + p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize()); + + constexpr auto dscale_dbias_thread_copy_step_m_k = + make_multi_index(0, KThreadClusterSize * 1); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + dscale_thread_buf(I) = type_convert(0.0f); + dbias_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_dscale_dbias_k_block_tile_iteration; + ++reducedTiles) + { + threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k, + reduce_dscale_global_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + reduce_dscale_thread_buf); + + threadwise_dscale_dbias_load_m_k.Run(dscale_dbias_grid_desc_m_k, + reduce_dbias_global_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + reduce_dbias_thread_buf); + + ThreadwiseReduce::Reduce(reduce_dscale_thread_buf, dscale_thread_buf); + ThreadwiseReduce::Reduce(reduce_dbias_thread_buf, dbias_thread_buf); + + threadwise_dscale_dbias_load_m_k.MoveSrcSliceWindow(dscale_dbias_grid_desc_m_k, + dscale_dbias_thread_copy_step_m_k); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I)); + block_sync_lds(); + BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I)); + }); + + threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m, + make_tuple(I0), + dscale_thread_buf, + dscale_dbias_grid_desc_m, + dscale_global_buf); + + threadwise_dscale_dbias_store_m.Run(thread_buffer_desc_m, + make_tuple(I0), + dbias_thread_buf, + dscale_dbias_grid_desc_m, + dbias_global_buf); + + // clang-format off + // Step 2: calculate dx = 1/N * inv-variance * scale * (N * dy - dbias - dscale * (x - mean) * inv-variance) + // clang-format on + + const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration; + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + workSizePerBlock * block_local_id + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dy_load = 
ThreadwiseTensorSliceTransfer_v2( + dy_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + workSizePerBlock * block_local_id + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dx_store = + ThreadwiseTensorSliceTransfer_v1r3( + dx_grid_desc_m_k, + make_multi_index( + blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + workSizePerBlock * block_local_id + thread_k_cluster_id * KThreadSliceSize), + PassThroughOp{}); + + auto threadwise_scale_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + ScaleSrcVectorSize, + 1, + true>( + scale_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + auto threadwise_mean_var_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + MeanVarSrcVectorSize, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + const auto x_global_buf = make_dynamic_buffer( + p_x, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto dy_global_buf = make_dynamic_buffer( + p_dy, dy_grid_desc_m_k.GetElementSpaceSize()); + + auto dx_global_buf = make_dynamic_buffer( + p_dx, dx_grid_desc_m_k.GetElementSpaceSize()); + + const auto scale_global_buf = make_dynamic_buffer( + p_scale, scale_grid_desc_m.GetElementSpaceSize()); + + const auto mean_global_buf = make_dynamic_buffer( + p_mean, mean_var_grid_desc_m.GetElementSpaceSize()); + + const auto inv_var_global_buf = make_dynamic_buffer( + p_inv_var, mean_var_grid_desc_m.GetElementSpaceSize()); + + threadwise_scale_load.Run(scale_grid_desc_m, + scale_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + scale_thread_buf); + + threadwise_mean_var_load.Run(mean_var_grid_desc_m, + mean_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf); + + threadwise_mean_var_load.Run(mean_var_grid_desc_m, + inv_var_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + inv_var_thread_buf); + + constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize); + + AccDataType inv_reduce_size = + type_convert(1.0) / type_convert(reduce_size); + + for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + AccDataType multiplier = + inv_reduce_size * inv_var_thread_buf[iM] * scale_thread_buf[iM]; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + dy_elementwise_op(dy_thread_buf(Number{}), + dy_thread_buf[Number{}]); + + AccDataType norm_x = (x_thread_buf[Number{}] - mean_thread_buf[iM]) * + inv_var_thread_buf[iM]; + + AccDataType tmpVal = norm_x * dscale_thread_buf[iM]; + + dx_thread_buf(Number{}) = + multiplier * + (type_convert(reduce_size) * dy_thread_buf[Number{}] - + dbias_thread_buf[iM] - tmpVal); + }); + }); + + threadwise_dx_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + dx_thread_buf, + dx_grid_desc_m_k, + dx_global_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k); + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k); + threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, 
xy_thread_copy_step_m_k); + } + }; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp new file mode 100755 index 000000000000..a82a173500d2 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp @@ -0,0 +1,261 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_multiblock_welford_first_half( + const XGridDesc_M_K x_grid_desc_m_k, + const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g, + const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, + index_t num_k_block_tile_iteration, + const XDataType* const __restrict__ p_x, + MeanVarDataType* const p_welford_mean, + MeanVarDataType* const p_welford_variance, + int32_t* const p_welford_count) +{ + GridwiseMultiblockWelfordFirstHalf_::Run(x_grid_desc_m_k, + mean_var_count_grid_desc_m_g, + get_reduce_count_per_thread, + num_k_block_tile_iteration, + p_x, + p_welford_mean, + p_welford_variance, + p_welford_count); +}; + +template +struct GridwiseMultiblockWelfordFirstHalf +{ + static_assert((XSrcCountSrcVectorDim == 0 && MThreadSliceSize % XSrcCountSrcVectorSize == 0) || + (XSrcCountSrcVectorDim == 1 && + KThreadSliceSize % XSrcCountSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcCountSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + // clang-format off + // First half of the Multiblock Welford method to calculate mean and variance, used by both batchnorm-forward and batchnorm-backward. 
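// --- Illustrative sketch (editor's addition, not part of the patch): per thread,
// ThreadwiseWelford applies the single-pass Welford update over its slice of x, and
// BlockwiseWelford then merges the per-thread partials across the workgroup. A scalar
// version over a plain array, with m2 holding the sum of squared deviations:
inline void welford_update_reference(const float* x, int n,
                                     float& mean, float& m2, int& count)
{
    mean  = 0.0f;
    m2    = 0.0f;
    count = 0;
    for(int i = 0; i < n; ++i)
    {
        ++count;
        const float delta = x[i] - mean;
        mean += delta / static_cast<float>(count);
        m2 += delta * (x[i] - mean); // uses the already-updated mean
    }
    // population variance would be m2 / count (for count > 0)
}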
+ // clang-format on + __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k, + const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g, + const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread, + index_t num_k_block_tile_iteration, + const XDataType* const __restrict__ p_x, + MeanVarDataType* const p_welford_mean, + MeanVarDataType* const p_welford_variance, + int32_t* const p_welford_count) + { + StaticBuffer + x_thread_buf; + + StaticBuffer + welford_mean_thread_buf; + StaticBuffer + welford_var_thread_buf; + StaticBuffer + welford_count_thread_buf; + + const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1); + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / blkgroup_size; + const index_t block_local_id = block_global_id % blkgroup_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M_1 = Sequence; + + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{})); + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_welford_mean_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_count_grid_desc_m_g, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_local_id), + PassThroughOp{}); + + auto threadwise_welford_count_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_count_grid_desc_m_g, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_local_id), + PassThroughOp{}); + + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x, x_grid_desc_m_k.GetElementSpaceSize()); + + auto welford_mean_global_val_buf = make_dynamic_buffer( + p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); + + auto welford_var_global_val_buf = make_dynamic_buffer( + p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); + + auto welford_count_global_val_buf = make_dynamic_buffer( + p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize()); + + auto threadwise_welford = ThreadwiseWelford(); + threadwise_welford.max_count_ = + get_reduce_count_per_thread(block_local_id, thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + welford_mean_thread_buf(I) = type_convert(0.0f); + welford_var_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + 
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_welford.Run(x_thread_buf, welford_mean_thread_buf, welford_var_thread_buf); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + welford_count_thread_buf(I) = threadwise_welford.cur_count_; + BlockwiseWelford::Run( + welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I)); + }); + + if(thread_k_cluster_id == 0) + { + threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + welford_mean_thread_buf, + mean_var_count_grid_desc_m_g, + welford_mean_global_val_buf); + + threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + welford_var_thread_buf, + mean_var_count_grid_desc_m_g, + welford_var_global_val_buf); + + threadwise_welford_count_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + welford_count_thread_buf, + mean_var_count_grid_desc_m_g, + welford_count_global_val_buf); + }; + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp new file mode 100755 index 000000000000..672be91a79ed --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp @@ -0,0 +1,570 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_welford_second_half_batchnorm_forward_final( + const XYGridDesc_M_K x_grid_desc_m_k, + const XYGridDesc_M_K y_grid_desc_m_k, + const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, + const ScaleBiasGridDesc_M scale_grid_desc_m, + const ScaleBiasGridDesc_M bias_grid_desc_m, + const MeanVarGridDesc_M mean_var_grid_desc_m, + index_t blkgroup_size, + index_t num_xy_k_block_tile_iteration, + AccDataType epsilon, + const MeanVarDataType* const __restrict__ p_in_welford_mean, + const MeanVarDataType* const __restrict__ p_in_welford_variance, + const int32_t* const __restrict__ p_in_welford_count, + const XDataType* const __restrict__ p_x, + const ScaleDataType* const __restrict__ p_scale, + const BiasDataType* const __restrict__ p_bias, + const YElementwiseOp y_elementwise_op, + YDataType* const __restrict__ p_y, + bool updateMovingAverage, + AccDataType averageFactor, + MeanVarDataType* const __restrict__ resultRunningMean, + MeanVarDataType* const __restrict__ resultRunningVariance, + bool saveMeanInvVariance, + MeanVarDataType* const __restrict__ resultSaveMean, + MeanVarDataType* const __restrict__ resultSaveInvVariance) +{ + GridwiseWelfordSecondHalfBatchNormForwardFinal_::Run(x_grid_desc_m_k, + y_grid_desc_m_k, + mean_var_count_grid_desc_m_k, + scale_grid_desc_m, + bias_grid_desc_m, + mean_var_grid_desc_m, + blkgroup_size, + 
num_xy_k_block_tile_iteration, + epsilon, + p_in_welford_mean, + p_in_welford_variance, + p_in_welford_count, + p_x, + p_scale, + p_bias, + y_elementwise_op, + p_y, + updateMovingAverage, + averageFactor, + resultRunningMean, + resultRunningVariance, + saveMeanInvVariance, + resultSaveMean, + resultSaveInvVariance); +}; + +template +struct GridwiseWelfordSecondHalfBatchNormForwardFinal +{ + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelfordMerge; + + using BlockwiseWelford = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k, + const XYGridDesc_M_K& y_grid_desc_m_k, + const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k, + const ScaleBiasGridDesc_M& scale_grid_desc_m, + const ScaleBiasGridDesc_M& bias_grid_desc_m, + const MeanVarGridDesc_M& mean_var_grid_desc_m, + index_t blkgroup_size, + index_t num_xy_k_block_tile_iteration, + AccDataType epsilon, + const MeanVarDataType* const __restrict__ p_in_welford_mean, + const MeanVarDataType* const __restrict__ p_in_welford_variance, + const int32_t* const __restrict__ p_in_welford_count, + const XDataType* const __restrict__ p_x, + const ScaleDataType* const __restrict__ p_scale, + const BiasDataType* const __restrict__ p_bias, + const YElementwiseOp y_elementwise_op, + YDataType* const __restrict__ p_y, + bool updateMovingAverage, + AccDataType averageFactor, + MeanVarDataType* const __restrict__ resultRunningMean, + MeanVarDataType* const __restrict__ resultRunningVariance, + bool saveMeanInvVariance, + MeanVarDataType* const __restrict__ resultSaveMean, + MeanVarDataType* const __restrict__ resultSaveInvVariance) + + { + using ck::math::sqrt; + + StaticBuffer + in_welford_mean_thread_buf; + StaticBuffer + in_welford_var_thread_buf; + StaticBuffer + in_welford_count_thread_buf; + + StaticBuffer + welford_mean_thread_buf; + StaticBuffer + welford_var_thread_buf; + StaticBuffer + welford_count_thread_buf; + + StaticBuffer + x_thread_buf; + StaticBuffer + y_thread_buf; + + StaticBuffer scale_thread_buf; + StaticBuffer bias_thread_buf; + + const index_t thread_local_id = 
get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / blkgroup_size; + const index_t block_local_id = block_global_id % blkgroup_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + using ThreadBufferLengths_M_1 = Sequence; + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{})); + + auto threadwise_mean_var_load_m_k = + ThreadwiseTensorSliceTransfer_v2, + 0, + 1, + 1, + true>( + mean_var_count_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * 1)); + + auto threadwise_count_load_m_k = + ThreadwiseTensorSliceTransfer_v2, + 0, + 1, + 1, + true>( + mean_var_count_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * 1)); + + const auto welford_mean_global_val_buf = make_dynamic_buffer( + p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); + + const auto welford_var_global_val_buf = make_dynamic_buffer( + p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); + + const auto welford_count_global_val_buf = make_dynamic_buffer( + p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); + + // Step 1: do final welford reduction to get mean and variance + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + welford_mean_thread_buf(I) = type_convert(0.0f); + welford_var_thread_buf(I) = type_convert(0.0f); + welford_count_thread_buf(I) = 0; + }); + + constexpr auto mean_var_count_thread_copy_step_m_k = + make_multi_index(0, KThreadClusterSize); + + int32_t reducedSize = 0; + while(reducedSize < blkgroup_size) + { + threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, + welford_mean_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_mean_thread_buf); + + threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, + welford_var_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_var_thread_buf); + + threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k, + welford_count_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_count_thread_buf); + + ThreadwiseWelford::Run(in_welford_mean_thread_buf, + in_welford_var_thread_buf, + in_welford_count_thread_buf, + welford_mean_thread_buf, + welford_var_thread_buf, + welford_count_thread_buf); + + reducedSize += KThreadClusterSize; + + threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, + mean_var_count_thread_copy_step_m_k); + threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, + mean_var_count_thread_copy_step_m_k); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseWelford::Run( + welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I)); + }); + + // Step 2: do normalization and output y + + const 
index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration; + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + workSizePerBlock * block_local_id + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index( + blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + workSizePerBlock * block_local_id + thread_k_cluster_id * KThreadSliceSize), + y_elementwise_op); + + auto threadwise_scale_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + ScaleSrcVectorSize, + 1, + true>( + scale_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2, + 0, + BiasSrcVectorSize, + 1, + true>( + bias_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto scale_global_val_buf = make_dynamic_buffer( + p_scale, scale_grid_desc_m.GetElementSpaceSize()); + + const auto bias_global_val_buf = make_dynamic_buffer( + p_bias, bias_grid_desc_m.GetElementSpaceSize()); + + auto y_global_val_buf = make_dynamic_buffer( + p_y, y_grid_desc_m_k.GetElementSpaceSize()); + + threadwise_scale_load.Run(scale_grid_desc_m, + scale_global_val_buf, + thread_buffer_desc_m, + make_tuple(I0), + scale_thread_buf); + + threadwise_bias_load.Run(bias_grid_desc_m, + bias_global_val_buf, + thread_buffer_desc_m, + make_tuple(I0), + bias_thread_buf); + + constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize); + + for(index_t workTiles = 0; workTiles < num_xy_k_block_tile_iteration; ++workTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + AccDataType multiplier = + scale_thread_buf[iM] / sqrt(welford_var_thread_buf[iM] + epsilon); + + AccDataType fused_mean_bias = + bias_thread_buf[iM] - welford_mean_thread_buf[iM] * multiplier; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + y_thread_buf(Number{}) = + x_thread_buf[Number{}] * multiplier + fused_mean_bias; + }); + }); + + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf, + y_grid_desc_m_k, + y_global_val_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, xy_thread_copy_step_m_k); + } + + // Step 3: update the moving average of mean and variance (optional) + + if(updateMovingAverage && block_local_id == 0 && thread_k_cluster_id == 0) + { + StaticBuffer + running_mean_thread_buf; + StaticBuffer + running_var_thread_buf; + + auto running_mean_global_buf = make_dynamic_buffer( + resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto running_var_global_buf = make_dynamic_buffer( + resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto threadwise_mean_var_load_m = + ThreadwiseTensorSliceTransfer_v2, + 0, + MeanVarSrcDstVectorSize, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * 
MThreadSliceSize)); + + threadwise_mean_var_load_m.Run(mean_var_grid_desc_m, + running_mean_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + running_mean_thread_buf); + + threadwise_mean_var_load_m.Run(mean_var_grid_desc_m, + running_var_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + running_var_thread_buf); + + AccDataType oneMinusAverageFactor = type_convert(1.0) - averageFactor; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor + + welford_mean_thread_buf[I] * averageFactor; + running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor + + welford_var_thread_buf[I] * averageFactor; + }); + + auto threadwise_mean_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + MeanVarSrcDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_mean_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + running_mean_thread_buf, + mean_var_grid_desc_m, + running_mean_global_buf); + + threadwise_mean_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + running_var_thread_buf, + mean_var_grid_desc_m, + running_var_global_buf); + }; + + // Step 4: save mean and inv-variance (optional) + + if(saveMeanInvVariance && block_local_id == 0 && thread_k_cluster_id == 0) + { + auto result_mean_global_buf = make_dynamic_buffer( + resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto result_inv_var_global_buf = make_dynamic_buffer( + resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize()); + + // calculate inv-variance as 1/sqrt(epsilon+variance) + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + welford_var_thread_buf(I) = + type_convert(1.0f) / sqrt(epsilon + welford_var_thread_buf[I]); + }); + + auto threadwise_mean_inv_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + MeanVarSrcDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + welford_mean_thread_buf, + mean_var_grid_desc_m, + result_mean_global_buf); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + welford_var_thread_buf, + mean_var_grid_desc_m, + result_inv_var_global_buf); + }; + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp new file mode 100755 index 000000000000..2d5dc90bfb37 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp @@ -0,0 +1,556 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
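// --- Illustrative sketch (editor's addition, not part of the patch): the header that
// follows accumulates, per workgroup, the two batchnorm-backward reductions
//   dbias  = sum(dy)
//   dscale = sum(dy * (x - mean) * inv_std)
// (see its Step 2 comment). A scalar reference over a plain array; names are illustrative.
inline void batchnorm_backward_partial_sums_reference(const float* x, const float* dy, int n,
                                                      float mean, float inv_std,
                                                      float& dbias, float& dscale)
{
    dbias  = 0.0f;
    dscale = 0.0f;
    for(int i = 0; i < n; ++i)
    {
        const float norm_x = (x[i] - mean) * inv_std;
        dbias += dy[i];
        dscale += dy[i] * norm_x;
    }
}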
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_welford_second_half_reduce_first_half( + const XYGridDesc_M_K x_grid_desc_m_k, + const XYGridDesc_M_K dy_grid_desc_m_k, + const MeanVarGridDesc_M mean_var_grid_desc_m, + const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, + const DscaleDbiasGridDesc_M_G dscale_dbias_grid_desc_m_g, + index_t blkgroup_size, + index_t num_xy_k_block_tile_iteration, + index_t num_mean_var_count_k_block_tile_iteration, + AccDataType epsilon, + bool haveSavedMeanInvVar, + const MeanVarDataType* const __restrict__ p_savedMean, + const MeanVarDataType* const __restrict__ p_savedInvVar, + const MeanVarDataType* const __restrict__ p_in_welford_mean, + const MeanVarDataType* const __restrict__ p_in_welford_variance, + const int32_t* const __restrict__ p_in_welford_count, + const DyElementwiseOp dy_elementwise_op, + MeanVarDataType* const __restrict__ p_out_welford_mean, + MeanVarDataType* const __restrict__ p_out_welford_inv_variance, + const XDataType* const __restrict__ p_x, + const DyDataType* const __restrict__ p_dy, + DscaleDbiasDataType* const __restrict__ p_reduce_dscale, + DscaleDbiasDataType* const __restrict__ p_reduce_dbias) +{ + GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k, + dy_grid_desc_m_k, + mean_var_grid_desc_m, + mean_var_count_grid_desc_m_k, + dscale_dbias_grid_desc_m_g, + blkgroup_size, + num_xy_k_block_tile_iteration, + num_mean_var_count_k_block_tile_iteration, + epsilon, + haveSavedMeanInvVar, + p_savedMean, + p_savedInvVar, + p_in_welford_mean, + p_in_welford_variance, + p_in_welford_count, + dy_elementwise_op, + p_out_welford_mean, + p_out_welford_inv_variance, + p_x, + p_dy, + p_reduce_dscale, + p_reduce_dbias); +}; + +template +struct GridwiseWelfordSecondHalfReduceFirstHalf +{ + static_assert((XDyVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 && + MThreadSliceSize % DySrcVectorSize == 0) || + (XDyVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 && + KThreadSliceSize % DySrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XDyVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelfordMerge; + + using BlockwiseWelford = BlockwiseWelford; + + using BlockwiseReduce = PartitionedBlockwiseReduction; + + using ThreadwiseReduce = ThreadwiseReduction; + + using PassThroughOp = 
tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + // clang-format off + // Two of the steps of Multiblock BatchNorm Backward + // Step 1: Second half of Welford method to calculate mean and variance, as well as getting inv-variance = 1/sqrt(epsilon+variance) + // Step 2: First half of Reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance) + // clang-format on + __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k, + const XYGridDesc_M_K& dy_grid_desc_m_k, + const MeanVarGridDesc_M& mean_var_grid_desc_m, + const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k, + const DscaleDbiasGridDesc_M_G& dscale_dbias_grid_desc_m_g, + index_t blkgroup_size, + index_t num_xy_k_block_tile_iteration, + index_t num_mean_var_count_k_block_tile_iteration, + AccDataType epsilon, + bool haveSavedMeanInvVar, + const MeanVarDataType* const __restrict__ p_savedMean, + const MeanVarDataType* const __restrict__ p_savedInvVar, + const MeanVarDataType* const __restrict__ p_in_welford_mean, + const MeanVarDataType* const __restrict__ p_in_welford_variance, + const int32_t* const __restrict__ p_in_welford_count, + const DyElementwiseOp dy_elementwise_op, + MeanVarDataType* const __restrict__ p_out_welford_mean, + MeanVarDataType* const __restrict__ p_out_welford_inv_variance, + const XDataType* const __restrict__ p_x, + const DyDataType* const __restrict__ p_dy, + DscaleDbiasDataType* const __restrict__ p_reduce_dscale, + DscaleDbiasDataType* const __restrict__ p_reduce_dbias) + { + __shared__ AccDataType p_reduce_work_buffer[BlockSize]; + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + StaticBuffer + in_welford_mean_thread_buf; + StaticBuffer + in_welford_var_thread_buf; + StaticBuffer + in_welford_count_thread_buf; + + StaticBuffer + welford_mean_thread_buf; + StaticBuffer + welford_var_thread_buf; + StaticBuffer + welford_count_thread_buf; + + StaticBuffer& mean_thread_buf = + welford_mean_thread_buf; + StaticBuffer& + inv_var_thread_buf = welford_var_thread_buf; + + StaticBuffer + x_thread_buf; + StaticBuffer + dy_thread_buf; + + // buffer of values of dy * (x-mean) * inv-variance, used as input of Blockwise reduction + StaticBuffer + tmp1_thread_buf; + + StaticBuffer + reduce_dscale_thread_buf; + StaticBuffer + reduce_dbias_thread_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / blkgroup_size; + const index_t block_local_id = block_global_id % blkgroup_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + using ThreadBufferLengths_M_1 = Sequence; + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{})); + + // clang-format off + // Step 
1: load existing mean and inv-variance, or do final welford reduction on mean and variance as well as get inv-variance = 1/sqrt(epsilon+variance) + // clang-format on + + if(haveSavedMeanInvVar) + { + const auto mean_global_buf = make_dynamic_buffer( + p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + const auto inv_var_global_buf = make_dynamic_buffer( + p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto threadwise_mean_inv_var_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + MeanVarSrcVectorSize, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m, + mean_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf); + + threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m, + inv_var_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + inv_var_thread_buf); + } + else + { + const auto welford_mean_global_buf = make_dynamic_buffer( + p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); + + const auto welford_var_global_buf = make_dynamic_buffer( + p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); + + const auto welford_count_global_buf = make_dynamic_buffer( + p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize()); + + auto threadwise_mean_var_load_m_k = + ThreadwiseTensorSliceTransfer_v2, + 1, + 1, + 1, + true>( + mean_var_count_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * 1)); + + auto threadwise_count_load_m_k = + ThreadwiseTensorSliceTransfer_v2, + 1, + 1, + 1, + true>( + mean_var_count_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * 1)); + + constexpr auto mean_var_count_thread_copy_step_m_k = + make_multi_index(0, KThreadClusterSize * 1); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + welford_mean_thread_buf(I) = type_convert(0.0f); + welford_var_thread_buf(I) = type_convert(0.0f); + welford_count_thread_buf(I) = 0; + }); + + for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration; + ++reducedTiles) + { + threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, + welford_mean_global_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_mean_thread_buf); + + threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k, + welford_var_global_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_var_thread_buf); + + threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k, + welford_count_global_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_count_thread_buf); + + ThreadwiseWelford::Run(in_welford_mean_thread_buf, + in_welford_var_thread_buf, + in_welford_count_thread_buf, + welford_mean_thread_buf, + welford_var_thread_buf, + welford_count_thread_buf); + + threadwise_mean_var_load_m_k.MoveSrcSliceWindow( + mean_var_count_grid_desc_m_k, mean_var_count_thread_copy_step_m_k); + threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k, + mean_var_count_thread_copy_step_m_k); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseWelford::Run(welford_mean_thread_buf(I), + welford_var_thread_buf(I), + welford_count_thread_buf(I)); + }); + + // calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of 
variance + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + welford_var_thread_buf(I) = + type_convert(1.0) / sqrt(welford_var_thread_buf[I] + epsilon); + }); + + if(block_local_id == 0 && thread_k_cluster_id == 0) + { + + auto threadwise_mean_inv_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto mean_global_buf = make_dynamic_buffer( + p_out_welford_mean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto inv_var_global_buf = make_dynamic_buffer( + p_out_welford_inv_variance, mean_var_grid_desc_m.GetElementSpaceSize()); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + mean_var_grid_desc_m, + mean_global_buf); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_var_thread_buf, + mean_var_grid_desc_m, + inv_var_global_buf); + }; + }; + + const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration; + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + workSizePerBlock * block_local_id + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2( + dy_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + workSizePerBlock * block_local_id + + thread_k_cluster_id * KThreadSliceSize)); + + const auto x_global_buf = make_dynamic_buffer( + p_x, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto dy_global_buf = make_dynamic_buffer( + p_dy, dy_grid_desc_m_k.GetElementSpaceSize()); + + constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + reduce_dscale_thread_buf(I) = type_convert(0); + reduce_dbias_thread_buf(I) = type_convert(0); + }); + + // clang-format off + // Step 2: first-half of reduction: dbias = sum(dy), dscale = sum(dy * (x-mean) * inv-variance) + // clang-format on + + for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + dy_elementwise_op(dy_thread_buf(Number{}), + dy_thread_buf[Number{}]); + + AccDataType norm_x = (x_thread_buf[Number{}] - mean_thread_buf[iM]) * + inv_var_thread_buf[iM]; + + tmp1_thread_buf(Number{}) = norm_x * dy_thread_buf[Number{}]; + }); + }); + + ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_dscale_thread_buf); + ThreadwiseReduce::Reduce(dy_thread_buf, reduce_dbias_thread_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k); + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k); + }; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseReduce::Reduce(reduce_work_buf, reduce_dscale_thread_buf(I)); + block_sync_lds(); + 
BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I)); + }); + + auto threadwise_dscale_dbias_store = + ThreadwiseTensorSliceTransfer_v1r3, + 1, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + dscale_dbias_grid_desc_m_g, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_local_id), + PassThroughOp{}); + + auto reduce_dscale_global_buf = make_dynamic_buffer( + p_reduce_dscale, dscale_dbias_grid_desc_m_g.GetElementSpaceSize()); + + auto reduce_dbias_global_buf = make_dynamic_buffer( + p_reduce_dbias, dscale_dbias_grid_desc_m_g.GetElementSpaceSize()); + + if(thread_k_cluster_id == 0) + { + threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + reduce_dscale_thread_buf, + dscale_dbias_grid_desc_m_g, + reduce_dscale_global_buf); + + threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + reduce_dbias_thread_buf, + dscale_dbias_grid_desc_m_g, + reduce_dbias_global_buf); + }; + }; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp new file mode 100755 index 000000000000..7eca68bbf88f --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp @@ -0,0 +1,1784 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/math.hpp" +#include "ck/utility/number.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +#include +#include +#endif + +namespace ck { + +// Rows of column-vectors +template +struct BlockToCTileMap_M00_N0_M01 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01() = default; + + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 1) + : M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01)) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01_); + + const index_t grid_size = M00 * M01_ * N0; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return underlying_map_.CalculateBottomIndex(idx_top); + } + + template + __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + if constexpr(DeviceCTileIndexCheck) + return DefaultValidCTileIndex(c_tile_idx, c_tile_dim); + else + return true; + } + + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + if constexpr(DeviceCTileIndexCheck) + return true; // validity check moved to kernel + + const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + if(M0 % M01_ == 0) + { + return true; + } + else + { + return false; + } + } + + private: + 
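    // Sketch of what GetBlockToCTileMap() below constructs (illustrative, assuming the
    // usual CK merge convention where the last listed dimension varies fastest): a linear
    // block id is decoded over (1, M00, N0, M01) and the (M00, M01) pair is re-merged into
    // the M-tile index m0 = m00 * M01 + m01. For example, with M0 = 4, N0 = 3, M01 = 2
    // (so M00 = 2, grid_size = 12):
    //   block id 0 -> C tile (m0 = 0, n0 = 0)
    //   block id 1 -> C tile (m0 = 1, n0 = 0)   // finish the M01-tall column first
    //   block id 2 -> C tile (m0 = 0, n0 = 1)   // then step along N ("rows of column-vectors")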
__host__ __device__ static constexpr auto + GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01) + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01); + + const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_insert_transform(1), + make_unmerge_transform(make_tuple(M00, M01)), + make_pass_through_transform(make_tuple(N0))), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + + const auto cblockid_to_m00_n0_m01_block_cluster_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(1, M00, N0, M01))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto cblockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_n0_m01_to_m0_n0_block_cluster_adaptor, + cblockid_to_m00_n0_m01_block_cluster_adaptor); + + return cblockid_to_m0_n0_block_cluster_adaptor; + } + + index_t M01_; + using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1)); + UnderlyingMap underlying_map_; +}; + +// Rows of column-vectors +// This C-tile map dynamically adjusts M01 when C-tile index is out of range +template +struct BlockToCTileMap_M00_N0_M01Adapt; + +template +struct BlockToCTileMap_M00_N0_M01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default; + + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt( + const BlockToCTileMap_M00_N0_M01Adapt&) = default; + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt( + BlockToCTileMap_M00_N0_M01Adapt&&) = default; + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt& + operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default; + __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt& + operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default; + + __host__ + __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8) + : M_(M), N_(N), M01_(M01) + { +#if 0 + if(get_thread_global_1d_id()==0){ + printf("Ctor called, M= %d, N= %d, M01 = %d\n", M_, N_, M01_); + } +#endif + } + + template + __host__ + __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 8) + : BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01) + { + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0; + } + + template + __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock); + + block_1d_id = block_1d_id % (M0 * N0); // swallow batch index + + 
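        // Note on the "Adapt" behaviour implemented below: the trailing row-group may hold
        // fewer than M01_ M-tiles (M0 % M01_ of them), so M01_adapt shrinks the group height
        // there and every in-range block id still maps to a valid C tile.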
index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + /** + * idxN0 + * + * |< mtx N >| + * + * NPerBlock NPerBlock NPerBlock NPerBlock + * N_0 N_1 N_2 N_3 + * - |-----------|-----------|-----------|-----|-----|- + * ^ | - - 0 |/----> 2 | | | | + * | | | / | | | | | M_0 MPerBlock + * | M | /| | | | | | + * |-0---|---/-|-----|-----|-----------|-----|-----|- + * | 1 | / | | | blockid | | | + * idxM0 | | | / | V | 5 | | | M_1 MPerBlock + * | - V 1 | - 3 | | | | + * |-----------|-----------|-----------|-----|-----|- + * mtx M | | | | | | + * | | | | | | M_2 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * | | | | | | + * | | | | | | M_3 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * V | | | | | | + * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * Example: + * assume: + * M0 = 5 + * N0 = 4 + * block_1d_id = 5 + * M01 = 2 + * + * idx_N0 = 1 + * idx_M0 = 1 + * M01_adapt = 2 + * idx_M00 = 0 + * idx_M01 = 1 + * idx_N0_M01_local = 5 + * output {1, 2} + */ + + return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t M01_; +}; + +// keep the redundant type argument for backward compatibility +template +struct BlockToCTileMap_M00_N0_M01Adapt : BlockToCTileMap_M00_N0_M01Adapt +{ + using BlockToCTileMap_M00_N0_M01Adapt:: + BlockToCTileMap_M00_N0_M01Adapt; +}; + +// Grouped Rows of column-vectors WGP mapping +// Optimized for gfx94x-like multipe-die chip + +template +struct BlockToCTileMap_Grouped_M00_N0_M01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt() = default; + __host__ __device__ BlockToCTileMap_Grouped_M00_N0_M01Adapt(index_t M, + index_t N, + index_t M01 = 8) + : M_(M), N_(N), M01_(M01) + { + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0; + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock); + + if(M0 == 1) + { + return make_tuple(0, block_1d_id); + } + else if(N0 == 1) + { + return make_tuple(block_1d_id, 0); + } + // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index + else + { + const auto group_size = math::integer_divide_ceil(M0 * N0, GroupNum); + const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0); + auto group_id_x = block_1d_id % GroupNum; + auto group_id_y = block_1d_id / GroupNum; + auto remap_block_1d_id = + group_id_x <= 
big_group_num + ? group_id_x * group_size + group_id_y + : group_id_x * group_size + big_group_num - group_id_x + group_id_y; + + index_t idx_N0 = remap_block_1d_id % N0; + index_t idx_M0 = remap_block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + /** + * idxN0 + * + * |< mtx N >| + * + * NPerBlock NPerBlock NPerBlock NPerBlock + * N_0 N_1 N_2 N_3 + * - |-----------|-----------|-----------|-----|-----|- + * ^ | - - 0 |/----> 2 | | | | + * | | | / | | | | | M_0 MPerBlock + * | M | /| | | | | | + * |-0---|---/-|-----|-----|-----------|-----|-----|- + * | 1 | / | | | blockid | | | + * idxM0 | | | / | V | 5 | | | M_1 MPerBlock + * | - V 1 | - 3 | | | | + * |-----------|-----------|-----------|-----|-----|- + * mtx M | | | | | | + * | | | | | | M_2 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * | | | | | | + * | | | | | | M_3 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * V | | | | | | + * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * Example: + * assume: + * M0 = 5 + * N0 = 4 + * block_1d_id = 5 + * M01 = 2 + * + * idx_N0 = 1 + * idx_M0 = 1 + * M01_adapt = 2 + * idx_M00 = 0 + * idx_M01 = 1 + * idx_N0_M01_local = 5 + * output {1, 2} + */ + + return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t M01_; +}; + +// columns of row-vectors +// This C-tile map dynamically adjusts N01 when C-tile index is out of range +template +struct BlockToCTileMap_N00_M0_N01Adapt; + +template +struct BlockToCTileMap_N00_M0_N01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt() = default; + + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const BlockToCTileMap_N00_M0_N01Adapt&) = + default; + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(BlockToCTileMap_N00_M0_N01Adapt&&) = + default; + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt& + operator=(const BlockToCTileMap_N00_M0_N01Adapt&) = default; + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt& + operator=(BlockToCTileMap_N00_M0_N01Adapt&&) = default; + + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(index_t M, index_t N, index_t N01 = 8) + : M_(M), N_(N), N01_(N01) + { +#if 0 + if(get_thread_global_1d_id()==0){ + printf("Ctor called, M= %d, N= %d, N01 = %d\n", M_, N_, N01_); + } +#endif + } + + template + __host__ __device__ BlockToCTileMap_N00_M0_N01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t N01 = 8) + : BlockToCTileMap_N00_M0_N01Adapt( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), N01) + { + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0; + } + + template + __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + return 
CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock); + + block_1d_id = block_1d_id % (M0 * N0); // swallow batch index + + index_t idx_M0 = block_1d_id % M0; + index_t idx_N0 = block_1d_id / M0; + + const auto N01_adapt = (idx_N0 < N0 - N0 % N01_) ? N01_ : N0 % N01_; + + index_t idx_N00 = idx_N0 / N01_; + index_t idx_N01 = idx_N0 % N01_; + index_t idx_M0_N01_local = idx_M0 + idx_N01 * M0; + + /** + * idxN0 + * + * |< mtx N >| + * + * |<---N01--->| + * - |-----------|-----------|-----------|-----|-----|- + * ^ | 0 ----------> 1 | | | | + * | | / | | | | M_0 MPerBlock + * | / | | | | + * |------/----------------|-----------|-----|-----|- + * | | | | | | | + * idxM0 | V | | | | | M_1 MPerBlock + * | 2 ----------> 3 | | | | + * |-----------|-----------|-----------|-----|-----|- + * mtx M | | blockid | | | | + * | | 5 | | | | M_2 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * | | | | | | + * | | | | | | M_3 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * V | | | | | | + * - |-----------|-----------|-----------|-----|-----|- M_4 MPerBlock + * | | | | | | + * |-----------|-----------|-----------|-----|-----|- + * NPerBlock NPerBlock NPerBlock NPerBlock + * N_0 N_1 N_2 N_3 + * Example: + * assume: + * N0 = 5 + * M0 = 4 + * block_1d_id = 5 + * N01 = 2 + * + * idx_M0 = 1 + * idx_N0 = 1 + * N01_adapt = 2 + * idx_N00 = 0 + * idx_N01 = 1 + * idx_M0_N01_local = 5 + * output {2, 1} + */ + + return make_tuple(idx_M0_N01_local / N01_adapt, + idx_M0_N01_local % N01_adapt + idx_N00 * N01_); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t N01_; +}; + +// 2D slices of column-vectors in 3D space +// This C-tile map dynamically adjusts M01 when C-tile index is out of range +template +struct BlockToCTileMap_KSplit_M00_N0_M01Adapt +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt() = default; + + __host__ __device__ BlockToCTileMap_KSplit_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 8, + index_t KSplit = 1) + : M01_(M01), KSplit_(KSplit), c_grid_desc_m_n_(c_grid_desc_m_n) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const index_t grid_size = M0 * N0 * KSplit_; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock); + const auto N0 = 
math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock); + + block_1d_id = block_1d_id % (M0 * N0 * KSplit_); // hide groups + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + private: + index_t M01_; + index_t KSplit_; + CGridDesc_M_N c_grid_desc_m_n_; +}; + +// Blocks of row-vectors +template +struct BlockToCTileMap_M00_N00_M01_N01 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ __device__ BlockToCTileMap_M00_N00_M01_N01() = default; + + __host__ __device__ BlockToCTileMap_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 1, + index_t N01 = 1) + : M01_(M01), N01_(N01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01)) + { + } + + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01_); + const auto N00 = math::integer_divide_ceil(N0, N01_); + + const index_t grid_size = M00 * M01_ * N00 * N01_; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return underlying_map_.CalculateBottomIndex(idx_top); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + if constexpr(DeviceCTileIndexCheck) + return DefaultValidCTileIndex(c_tile_idx, c_tile_dim); + else + return true; + } + + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + if constexpr(DeviceCTileIndexCheck) + return true; // validity check moved to kernel + + const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + if(M0 % M01_ == 0 && N0 % N01_ == 0) + { + return true; + } + else + { + return false; + } + } + + private: + __host__ __device__ static constexpr auto + GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01); + const auto N00 = math::integer_divide_ceil(N0, N01); + + const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions + 
make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); + + const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(1, M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto cblockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); + + return cblockid_to_m0_n0_block_cluster_adaptor; + } + + index_t M01_, N01_; + using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1)); + UnderlyingMap underlying_map_; +}; + +// 2D slices of row-vectors in 3D space +template +struct BlockToCTileMap_KSplit_M00_N00_M01_N01 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + __host__ BlockToCTileMap_KSplit_M00_N00_M01_N01() = default; + + __host__ BlockToCTileMap_KSplit_M00_N00_M01_N01(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01 = 1, + index_t N01 = 1, + index_t KSplit = 1) + : c_grid_desc_m_n_(c_grid_desc_m_n), + M01_(M01), + N01_(N01), + KSplit_(KSplit), + underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01, N01, KSplit)) + { + } + + __host__ __device__ constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01_); + const auto N00 = math::integer_divide_ceil(N0, N01_); + + const index_t grid_size = M00 * M01_ * N00 * N01_ * KSplit_; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + static_assert(TopIdx::Size() == 1); + + return underlying_map_.CalculateBottomIndex( + make_multi_index(idx_top[I0] % CalculateGridSize())); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + if constexpr(DeviceCTileIndexCheck) + return DefaultValidCTileIndex(c_tile_idx, c_tile_dim); + else + return true; + } + + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + if constexpr(DeviceCTileIndexCheck) + return true; // validity check moved to kernel + + const index_t M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const index_t N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + if(M0 % M01_ == 0 && N0 % N01_ == 0) + { + return true; + } + else + { + return false; + } + } + + private: + __device__ constexpr index_t CalculateGridSize() const + { + return CalculateGridSize(c_grid_desc_m_n_); + } + + __host__ static constexpr auto GetBlockToCTileMap(const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01, + index_t KSplit) + { + const auto M0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock); + const auto N0 = math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock); + + const auto M00 = math::integer_divide_ceil(M0, M01); + const auto N00 = math::integer_divide_ceil(N0, N01); + + const auto 
ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(KSplit), + make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); + + const auto c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(KSplit, M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto c_blockid_to_ksplit_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(ksplit_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + c_blockid_to_ksplit_m00_m01_n00_n01_block_cluster_adaptor); + + return c_blockid_to_ksplit_m0_n0_block_cluster_adaptor; + } + + CGridDesc_M_N c_grid_desc_m_n_; + index_t M01_, N01_, KSplit_; + using UnderlyingMap = decltype(GetBlockToCTileMap(CGridDesc_M_N{}, 1, 1, 1)); + UnderlyingMap underlying_map_; +}; + +template +__host__ __device__ bool DefaultValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) +{ + bool is_valid = false; + + const index_t m_block = c_tile_dim[Number<0>{}]; + const index_t n_block = c_tile_dim[Number<1>{}]; + + if constexpr(CTileIdx::Size() == 2) + { + const index_t m_block_idx = c_tile_idx[Number<0>{}]; + const index_t n_block_idx = c_tile_idx[Number<1>{}]; + if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block) + { + is_valid = true; + } + } + else if constexpr(CTileIdx::Size() == 3) + { + const index_t ksplit_idx = c_tile_idx[Number<0>{}]; + const index_t m_block_idx = c_tile_idx[Number<1>{}]; + const index_t n_block_idx = c_tile_idx[Number<2>{}]; + if(0 <= m_block_idx && m_block_idx < m_block && 0 <= n_block_idx && n_block_idx < n_block) + { + is_valid = true; + } + ignore = ksplit_idx; + } + + return is_valid; +} + +// This wrapper class is for grouped gemm where it subtracts blockIdx by a value so that the +// workgroups assigned to a given gemm problem have top index offsetted to range [0, +// grid_size_per_gemm] +template +struct OffsettedBlockToCTileMap +{ + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMap() = default; + __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map, + index_t block_start) + { + block_to_ctile_map_ = block_to_ctile_map; + block_start_ = block_start; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] - block_start_)); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + template + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + return block_to_ctile_map_.CalculateGridSize(M, N); + } + + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t block_start_; +}; +// second 
version with 2 offsets +template +struct OffsettedBlockToCTileMap2 +{ + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMap2(UnderlyingBlockToCTileMap block_to_ctile_map, + index_t group_offset, + index_t tile_offset) + : block_to_ctile_map_{block_to_ctile_map}, + group_offset_{group_offset}, + tile_offset_{tile_offset} + { + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + return block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_)); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + return block_to_ctile_map_.CalculateGridSize(M, N); + } + + __device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; } + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t group_offset_; + index_t tile_offset_; +}; + +/** + * @brief Simple tile mapping which creates 3D grid of block of threads. + * + * @paragraph Description + * This Block-to-C-tile-map creates a 3D grid (n_blocks, m_blocks, z_blocks) of thread + * blocks. The first 2D are regular 2D tiles created by division of output GEMM + * dimenions by corresponding tile size. The third dimension (Z) is a k-split dimension, + * which denotes the number of blocks we use to divide work on GEMM K dimension onto. + * + * @tparam MPerBlock Output block tile size in M dimension. + * @tparam NPerBlock Output block tile size in N dimension. 
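 *
 * @paragraph Example
 * Illustrative numbers only: with M = 512, N = 768, MPerBlock = 256, NPerBlock = 128 and
 * k_split = 2, CalculateGridSize(M, N, k_split) returns the 3D grid (N0, M0, k_split) =
 * (6, 2, 2), and CalculateBottomIndex() simply maps a block at (blockIdx.x, blockIdx.y,
 * blockIdx.z) = (3, 1, 0) to the C-tile index (k-split id, M tile, N tile) = (0, 1, 3).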
+ */ +template +struct BlockToCTileMap_3DGrid_KSplit +{ + + __host__ __device__ BlockToCTileMap_3DGrid_KSplit() = default; + + __host__ __device__ constexpr auto + CalculateGridSize(index_t M, index_t N, index_t k_split) const + { + // Create 3D grid + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + return make_tuple(N0, M0, k_split); + } + + template + __device__ constexpr auto CalculateBottomIndex(const TopIdx&) const + { + return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + template + __host__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } +}; + +enum StreamKReductionStrategy +{ + Atomic = 0, // sk block use atomic to do reduction + Reduction, // let some workgroup responsible for doing the reduction operation +}; + +template +struct BlockToCTileMap_GemmStreamK +{ + static constexpr uint32_t min_k_iters_per_sk_block = 2; + static constexpr uint32_t MPerBlock = MPerBlock_; + static constexpr uint32_t NPerBlock = NPerBlock_; + static constexpr uint32_t KPerBlock = KPerBlock_; + static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategy_; + static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_; + + //-------------------------------------- + // pass to device + uint32_t sk_num_blocks; + uint32_t sk_num_big_blocks; + uint32_t dp_start_block_idx; + uint32_t reduction_start_block_idx; + uint32_t k_iters_per_big_block; + MDiv2 n_tiles; + MDiv k_iters_per_tile; + MDiv eqav_tiles_big; // for reduction + MDiv eqav_tiles_little; // for reduction + + // MDiv tile_swizzle_sub_m_rem; + //-------------------------------------- + + // prefer construct on host + BlockToCTileMap_GemmStreamK(uint32_t m, + uint32_t n, + uint32_t k, + uint32_t num_cu, + uint32_t occupancy, + uint32_t sk_blocks = 0xffffffff) + { + uint32_t num_tiles = + math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock); + k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock)); + + // one cu can hold one wg at one time, from the whole chip's point of view + // if number of wg is same as num_cu, we call it 1 dispatch + // if number of wg is 2x num_cu, we call it 2 dispatches. + // one dispatch can deliver wg same as num_cu (full dispatch), or less than num_cu (partial + // dispatch) + // + uint32_t full_dispatches = num_tiles / num_cu; + uint32_t full_dispatch_tiles = full_dispatches * num_cu; + uint32_t partial_dispatche_tiles = num_tiles - full_dispatch_tiles; + + uint32_t sk_occupancy = occupancy; + uint32_t dp_tiles = full_dispatch_tiles; + uint32_t sk_tiles = partial_dispatche_tiles; + + if(full_dispatches < occupancy) + { + // in this case, we allocate all blocks as sk blocks + // sk_occupancy = occupancy - full_dispatches; + sk_occupancy = 1; // TODO: single occ seems better + dp_tiles = full_dispatch_tiles; + sk_tiles = partial_dispatche_tiles; + } + else if((occupancy > 1) && (full_dispatches % occupancy == occupancy - 1)) + { + // e.g. occupancy = 2, full_dispatches = 3, 5, 7 ... + // occupancy = 3, full_dispatches = 5, 8, 11 ... + // occupancy = 4, full_dispatches = 7, 11 ... 
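            // (Reading of the case above, stated as an assumption rather than a verified
            // invariant: when full_dispatches % occupancy == occupancy - 1, the last
            // occupancy round already leaves one free slot per CU, so the stream-k blocks
            // can occupy it without taking a whole dispatch away from the DP tiles, which
            // is what the general else-branch below does.)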
+ sk_occupancy = 1; // left 1 slot for sk occupancy + dp_tiles = full_dispatch_tiles; + sk_tiles = partial_dispatche_tiles; + } + else + { + // others, we reduce 1 dispatch from dp, together with partial dispatch, + // to construct sk dispatch + sk_occupancy = occupancy - ((full_dispatches - 1) % occupancy); + dp_tiles = full_dispatch_tiles - num_cu; + sk_tiles = partial_dispatche_tiles + num_cu; + } + + // uint32_t dp_iters_per_block = k_iters_per_tile.get(); + uint32_t sk_total_iters = k_iters_per_tile.get() * sk_tiles; + uint32_t dp_num_blocks = 0; + + { + uint32_t min_sk_tiles = (sk_tiles >= num_cu) ? num_cu : (sk_tiles + 1); + uint32_t max_sk_tiles = + (sk_tiles >= num_cu) ? num_cu * sk_occupancy + : math::min(num_cu, sk_total_iters / min_k_iters_per_sk_block); + + // if use dp for sk-block, how many iters do we need + uint32_t dp_for_sk_iters = k_iters_per_tile.get(); + + uint32_t best_sk_score = + NumericLimits::Max(); // we need to find the smallest sk iters + for(uint32_t tentative_sk_blocks = min_sk_tiles; tentative_sk_blocks < max_sk_tiles; + tentative_sk_blocks++) + { + uint32_t tentative_sk_iters_per_block = + (sk_total_iters + tentative_sk_blocks - 1) / tentative_sk_blocks; + uint32_t tentative_sk_iters = tentative_sk_iters_per_block; + uint32_t sk_blocks_per_tile = (tentative_sk_blocks + sk_tiles - 1) / sk_tiles; + + // TODO: carefully adjust this parameter + // the more sk_blocks_per_tile, the worse the overhead + uint32_t cross_sk_blocks_overhead = sk_blocks_per_tile; + if(tentative_sk_blocks % sk_tiles != 0) + { + // penalty for uneven divide + cross_sk_blocks_overhead += + sk_blocks_per_tile * tentative_sk_iters_per_block / 50; + } + + uint32_t tentative_sk_score = tentative_sk_iters + cross_sk_blocks_overhead; + + if(tentative_sk_score < best_sk_score) + { + best_sk_score = tentative_sk_score; + sk_num_blocks = tentative_sk_blocks; + } + } + + if(best_sk_score >= dp_for_sk_iters) + { + sk_num_blocks = 0; + } + + // give a chance to control num of sk blocks + sk_num_blocks = sk_blocks != 0xffffffff ? sk_blocks : sk_num_blocks; + + if(sk_num_blocks == 0) + { + sk_num_big_blocks = 0; + k_iters_per_big_block = 0; + + dp_num_blocks = num_tiles; // all tile to be dp block + dp_start_block_idx = 0; + sk_total_iters = 0; // clear this tiles + } + else + { + // k_iters_per_sk_block is the floor of avg each ck block loop over tiles. 
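            // (Equivalently, k_iters_per_sk_block = sk_total_iters / sk_num_blocks rounded
            // down; see the derivation just below. Worked instance of that split, numbers
            // chosen for illustration only: sk_total_iters = 10, sk_num_blocks = 4 gives
            // m = 2 and b = 10 - 2 * 4 = 2, i.e. two "big" blocks run 3 iterations each and
            // two "little" blocks run 2, covering all 10 iterations.)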
+ // we need to decide how many iters for each sk block + // let m = k_iters_per_sk_block + // some of the sk block (little) will cover m iters, some (big) will cover m+1 + // we have + // 1) l + b = sk_blocks + // 2) l * m + b * (m + 1) = sk_total_iters + // => (l + b) * m + b = sk_total_iters + // => sk_blocks * m + b = sk_total_iters + // => b = sk_total_iters - m * sk_blocks + // NOTE: big could be zero + uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks; + sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks; + k_iters_per_big_block = k_iters_per_sk_block + 1; + + dp_num_blocks = dp_tiles; + dp_start_block_idx = (sk_num_blocks + num_cu - 1) / num_cu * num_cu; + } + } + n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock)); + reduction_start_block_idx = dp_start_block_idx + dp_num_blocks; + + if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) + { + uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get()); + uint32_t upper_little = math::lcm(k_iters_per_big_block - 1, k_iters_per_tile.get()); + eqav_tiles_big = MDiv(upper_big / k_iters_per_tile.get()); + eqav_tiles_little = MDiv(upper_little / k_iters_per_tile.get()); + } + +#if 0 + printf("cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, " + "sk_num_blocks:%d, " + "sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, " + "k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, " + "sk_tiles:%u, workspace(acc float):%u\n", + num_cu, + occupancy, + get_grid_dims().x, + num_tiles, + dp_tiles, + sk_num_big_blocks, + sk_num_blocks, + sk_total_iters, + dp_start_block_idx, + dp_iters_per_block, + dp_num_blocks, + k_iters_per_tile.get(), + k_iters_per_big_block, + reduction_start_block_idx, + get_sk_tiles(), + get_workspace_size(sizeof(float))); +#endif + } + + __host__ __device__ uint32_t get_sk_total_iters() const + { + uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block + + (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1); + return sk_total_iters; + } + + __host__ __device__ uint32_t get_sk_tiles() const + { + // tiles for sk + uint32_t sk_total_iters = get_sk_total_iters(); + return k_iters_per_tile.div(sk_total_iters); + } + + __host__ __device__ dim3 get_grid_dims() const + { + if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction) + { + return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1); + } + else + return dim3(reduction_start_block_idx, 1, 1); + } + + __device__ uint32_t get_block_idx() const + { + // TODO: swizzle block index for better locality + return __builtin_amdgcn_readfirstlane(blockIdx.x); + } + + __device__ void + get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const + { + if(block_idx < sk_num_big_blocks) + { + iter_start = block_idx * k_iters_per_big_block; + iter_end = iter_start + k_iters_per_big_block; + } + else if(block_idx < sk_num_blocks) + { + iter_start = (sk_num_big_blocks * k_iters_per_big_block) + + (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1); + iter_end = iter_start + (k_iters_per_big_block - 1); + } + else if(block_idx >= dp_start_block_idx) + { + uint32_t sk_total_iters = get_sk_total_iters(); + uint32_t dp_iters_per_block = k_iters_per_tile.get(); + iter_start = sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block; + iter_end = iter_start + dp_iters_per_block; + } + } + + __device__ uint32_t get_current_iter_length(uint32_t 
iter_start, + uint32_t iter_end, + uint32_t total_iter_length) const + { + uint32_t iter_length_mod, iter_length_quo /*unused*/; + k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod); + uint32_t current_iter_length = math::min( + iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length); + return current_iter_length; + } + + __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); } + + __device__ void + get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const + { + k_iters_per_tile.divmod(iter, tile_idx, iter_offset); + } + + __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const + { + uint32_t m_tile_idx, n_tile_idx; + uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock); + n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx); + + // swizzle tile + uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock); + + uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m; + + const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem)) + ? tile_swizzle_sub_m + : tile_swizzle_sub_m_rem; + + uint32_t m_tile_idx_sub0, m_tile_idx_sub1; + m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m; + m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m; + + uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value; + + uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt; + + n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt; + m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt; + return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m, + n_tile_idx_with_adapt); + } + + __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const + { + static constexpr uint32_t alignment = 128; + uint32_t acc_buffer_bytes = + MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes; + return (acc_buffer_bytes + alignment - 1) / alignment * alignment; + } + + __host__ __device__ uint32_t get_workspace_size_for_semaphore() const + { + return get_sk_tiles() * sizeof(uint32_t); + } + + __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const + { + return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore(); + } + + __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_, + const MDiv& eqav_tiles_) const + { + uint32_t tile_idx_ = tiles_ == 0 ? 
0 : (tiles_ - 1); + uint32_t max_eqav_tiles_ = eqav_tiles_.get() - 1; + uint32_t quo_, rem_; + eqav_tiles_.divmod(tile_idx_, quo_, rem_); + return quo_ * max_eqav_tiles_ + rem_; + } + + __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_, + uint32_t iters_per_sk_block_) const + { + return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() - + 1); + } + + __host__ __device__ uint32_t get_total_acc_buffers() const + { + uint32_t tiles_cover_big_blocks = + get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block); + uint32_t tiles_cover_little_blocks = + get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1); + + uint32_t total_intersec_big = + get_tile_intersections(tiles_cover_big_blocks, eqav_tiles_big); + uint32_t total_intersec_little = + get_tile_intersections(tiles_cover_little_blocks, eqav_tiles_little); + + return sk_num_blocks + total_intersec_big + total_intersec_little; + } + + __device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const + { + // TODO: from big to little + uint32_t tiles_cover_big_blocks = + get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block); + if(tile_idx_ < tiles_cover_big_blocks) + { + uint32_t touched_sk_blocks = + (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) / + k_iters_per_big_block; + uint32_t current_intersec = get_tile_intersections(tile_idx_, eqav_tiles_big); + return touched_sk_blocks + current_intersec; + } + else + { + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_; + uint32_t touched_sk_blocks = + (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) / + iters_per_little_sk_block; + uint32_t current_intersec = + get_tile_intersections(tile_idx_little_reverse, eqav_tiles_little); + return get_total_acc_buffers() - (touched_sk_blocks + current_intersec); + } + } + + __device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const + { + uint32_t iters_per_big_sk_block = k_iters_per_big_block; + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + if(block_idx_ < sk_num_big_blocks) + { + uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block + + k_iters_per_tile.get() - 1); + uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_big); + return block_idx_ + current_intersec; + } + else + { + uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_; + uint32_t touched_tiles = k_iters_per_tile.div( + block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1); + uint32_t current_intersec = get_tile_intersections(touched_tiles, eqav_tiles_little); + return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec); + } + } +}; + +template +struct BlockToCTileMap_GemmStreamK_v2 +{ + static constexpr uint32_t min_k_iters_per_sk_block = 2; + static constexpr uint32_t MPerBlock = MPerBlock_; + static constexpr uint32_t NPerBlock = NPerBlock_; + static constexpr uint32_t KPerBlock = KPerBlock_; + static constexpr uint32_t tile_swizzle_sub_m = TileSwizzleSubM_; + + //-------------------------------------- + // pass to device + mutable uint32_t sk_num_blocks; + uint32_t sk_num_big_blocks; + uint32_t dp_start_block_idx; + uint32_t reduction_start_block_idx; + uint32_t k_iters_per_big_block; + MDiv2 n_tiles; + MDiv k_iters_per_tile; + MDiv equiv_tiles_big; // for reduction + MDiv 
equiv_tiles_little; // for reduction + StreamKReductionStrategy reduction_strategy; + + // prefer construct on host + __host__ __device__ BlockToCTileMap_GemmStreamK_v2( + uint32_t m, + uint32_t n, + uint32_t k, + uint32_t grid_size = 1, + uint32_t streamk_sel = 1, + StreamKReductionStrategy reduction_strategy_ = StreamKReductionStrategy::Atomic) + : reduction_strategy(reduction_strategy_) + { + + // total output tiles + uint32_t num_tiles = + math::integer_divide_ceil(m, MPerBlock) * math::integer_divide_ceil(n, NPerBlock); + k_iters_per_tile = MDiv(math::integer_divide_ceil(k, KPerBlock)); + + uint32_t dp_tiles, dp_num_blocks, sk_total_iters; + + // Ensure grid_size is at least 1 to avoid division by zero + grid_size = math::max(grid_size, 1u); + + // default to regular DP GEMM if sk blocks == 0 + if(streamk_sel == 0) + { + sk_num_blocks = 0; + dp_tiles = num_tiles; + sk_num_big_blocks = 0; + k_iters_per_big_block = 0; + + dp_num_blocks = num_tiles; // all tile to be dp block + dp_start_block_idx = 0; + sk_total_iters = 0; // clear this tiles + } + // 2-tile sk + DP GEMM + else + { + // check if there's enough work for DP+ stream-k + bool bigEnough = num_tiles > grid_size; + + // Select between stream-k strategies + // Add safety checks to prevent zero or negative values + uint32_t sk_tiles = 0; + if(streamk_sel == 1) // 1 tile stream-k + { + sk_tiles = bigEnough ? (num_tiles % grid_size) : num_tiles; + + // Ensure sk_tiles is at least 1 + sk_tiles = math::max(sk_tiles, 1u); + } + else if(streamk_sel == 2) // 2-tile stream-k + { + sk_tiles = bigEnough ? (grid_size + num_tiles % grid_size) : num_tiles; + + // Ensure sk_tiles is at least 1 but not more than num_tiles + sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles); + } + else if(streamk_sel == 3) // 3-tile stream-k + { + sk_tiles = (num_tiles > (2 * grid_size)) ? (2 * grid_size + num_tiles % grid_size) + : num_tiles; + + // Ensure sk_tiles is at least 1 but not more than num_tiles + sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles); + } + else if(streamk_sel == 4) // 4-tile stream-k + { + sk_tiles = (num_tiles > (3 * grid_size)) ? (3 * grid_size + num_tiles % grid_size) + : num_tiles; + + // Ensure sk_tiles is at least 1 but not more than num_tiles + sk_tiles = math::min(math::max(sk_tiles, 1u), num_tiles); + } + + sk_num_blocks = sk_tiles; + // Remaining tiles are DP tiles + dp_tiles = bigEnough ? (num_tiles - sk_tiles) : 0; + + sk_total_iters = k_iters_per_tile.get() * sk_tiles; + + // k_iters_per_sk_block is the floor of avg each ck block loop over tiles. 
+ // we need to decide how many iters for each sk block + // let m = k_iters_per_sk_block + // some of the sk block (little) will cover m iters, some (big) will cover m+1 + // we have + // 1) l + b = sk_blocks + // 2) l * m + b * (m + 1) = sk_total_iters + // => (l + b) * m + b = sk_total_iters + // => sk_blocks * m + b = sk_total_iters + // => b = sk_total_iters - m * sk_blocks + // NOTE: big could be zero + + // Add safety check for sk_num_blocks to prevent division by zero + if(sk_num_blocks > 0) + { + uint32_t k_iters_per_sk_block = sk_total_iters / sk_num_blocks; + sk_num_big_blocks = sk_total_iters - k_iters_per_sk_block * sk_num_blocks; + k_iters_per_big_block = k_iters_per_sk_block + 1; + } + else + { + // Fallback to default GEMM if no stream-k blocks + sk_num_blocks = 0; + sk_num_big_blocks = 0; + k_iters_per_big_block = 0; + dp_tiles = num_tiles; + dp_num_blocks = num_tiles; + dp_start_block_idx = 0; + sk_total_iters = 0; + } + + dp_num_blocks = dp_tiles; + dp_start_block_idx = sk_num_blocks; + } + + n_tiles = MDiv2(math::integer_divide_ceil(n, NPerBlock)); + // Using multiple blocks for parallel reduction + reduction_start_block_idx = dp_start_block_idx + dp_num_blocks; + + if(reduction_strategy == ck::StreamKReductionStrategy::Reduction) + { + // Add additional safety checks + if(k_iters_per_big_block > 0 && k_iters_per_tile.get() > 0) + { + uint32_t upper_big = math::lcm(k_iters_per_big_block, k_iters_per_tile.get()); + uint32_t upper_little = + math::lcm(math::max(k_iters_per_big_block - 1, 1u), k_iters_per_tile.get()); + equiv_tiles_big = MDiv(upper_big / k_iters_per_tile.get()); + equiv_tiles_little = MDiv(upper_little / k_iters_per_tile.get()); + } + else + { + // Default safe values + equiv_tiles_big = MDiv(1); + equiv_tiles_little = MDiv(1); + } + } + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0; + } + __host__ __device__ uint32_t get_sk_total_iters() const + { + uint32_t sk_total_iters = sk_num_big_blocks * k_iters_per_big_block + + (sk_num_blocks - sk_num_big_blocks) * (k_iters_per_big_block - 1); + return sk_total_iters; + } + + __host__ __device__ uint32_t get_sk_tiles() const + { + // tiles for sk + uint32_t sk_total_iters = get_sk_total_iters(); + return k_iters_per_tile.div(sk_total_iters); + } + + __host__ __device__ index_t get_grid_dims() const + { + if(reduction_strategy == StreamKReductionStrategy::Reduction) + { + // return dim3(reduction_start_block_idx + get_sk_tiles(), 1, 1); + return reduction_start_block_idx + get_sk_tiles(); + } + else + return reduction_start_block_idx; + } + + __device__ uint32_t get_block_idx() const + { + // TODO: swizzle block index for better locality + return __builtin_amdgcn_readfirstlane(blockIdx.x); + } + + __device__ void + get_block_itr(uint32_t block_idx, uint32_t& iter_start, uint32_t& iter_end) const + { + if(block_idx < sk_num_big_blocks) + { + iter_start = block_idx * k_iters_per_big_block; + iter_end = iter_start + k_iters_per_big_block; + } + else if(block_idx < sk_num_blocks) + { + iter_start = (sk_num_big_blocks * k_iters_per_big_block) + + (block_idx - sk_num_big_blocks) * (k_iters_per_big_block - 1); + iter_end = iter_start + (k_iters_per_big_block - 1); + } + else if(block_idx >= dp_start_block_idx) + { + uint32_t sk_total_iters = get_sk_total_iters(); + uint32_t dp_iters_per_block = k_iters_per_tile.get(); + iter_start = 
sk_total_iters + (block_idx - dp_start_block_idx) * dp_iters_per_block; + iter_end = iter_start + dp_iters_per_block; + } + } + + __device__ uint32_t get_current_iter_length(uint32_t iter_start, + uint32_t iter_end, + uint32_t total_iter_length) const + { + uint32_t iter_length_mod, iter_length_quo /*unused*/; + k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod); + uint32_t current_iter_length = math::min( + iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, total_iter_length); + return current_iter_length; + } + + __device__ uint32_t get_tile_idx(uint32_t iter) const { return k_iters_per_tile.div(iter); } + + __device__ void + get_tile_idx_with_offset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const + { + k_iters_per_tile.divmod(iter, tile_idx, iter_offset); + } + + __device__ auto tile_to_spatial(uint32_t tile_idx, uint32_t m, uint32_t n) const + { + uint32_t m_tile_idx, n_tile_idx; + uint32_t n_tiles_value = math::integer_divide_ceil(n, NPerBlock); + n_tiles.divmod(tile_idx, n_tiles_value, m_tile_idx, n_tile_idx); + + // // swizzle tile + uint32_t m_tiles = math::integer_divide_ceil(m, MPerBlock); + + uint32_t tile_swizzle_sub_m_rem = m_tiles % tile_swizzle_sub_m; + + const auto sub_m_adapt = (m_tile_idx < (m_tiles - tile_swizzle_sub_m_rem)) + ? tile_swizzle_sub_m + : tile_swizzle_sub_m_rem; + + uint32_t m_tile_idx_sub0, m_tile_idx_sub1; + m_tile_idx_sub0 = m_tile_idx / tile_swizzle_sub_m; + m_tile_idx_sub1 = m_tile_idx % tile_swizzle_sub_m; + + uint32_t tile_idx_local = n_tile_idx + m_tile_idx_sub1 * n_tiles_value; + + uint32_t m_tile_idx_with_adapt, n_tile_idx_with_adapt; + + n_tile_idx_with_adapt = tile_idx_local / sub_m_adapt; + m_tile_idx_with_adapt = tile_idx_local % sub_m_adapt; + return make_tuple(m_tile_idx_with_adapt + m_tile_idx_sub0 * tile_swizzle_sub_m, + n_tile_idx_with_adapt); + } + + __host__ __device__ uint32_t get_workspace_size_for_acc(uint32_t acc_element_bytes) const + { + static constexpr uint32_t alignment = 128; + uint32_t acc_buffer_bytes = + MPerBlock * NPerBlock * get_total_acc_buffers() * acc_element_bytes; + return (acc_buffer_bytes + alignment - 1) / alignment * alignment; + } + + __host__ __device__ uint32_t get_workspace_size_for_semaphore() const + { + return get_sk_tiles() * sizeof(uint32_t); + } + + __host__ __device__ uint32_t get_workspace_size(uint32_t acc_element_bytes) const + { + return get_workspace_size_for_acc(acc_element_bytes) + get_workspace_size_for_semaphore(); + } + + __host__ __device__ uint32_t get_tile_intersections(uint32_t tiles_, + const MDiv& equiv_tiles_) const + { + uint32_t tile_idx_ = tiles_ == 0 ? 
0 : (tiles_ - 1); + uint32_t max_equiv_tiles_ = equiv_tiles_.get() - 1; + uint32_t quo_, rem_; + equiv_tiles_.divmod(tile_idx_, quo_, rem_); + return quo_ * max_equiv_tiles_ + rem_; + } + + __host__ __device__ uint32_t get_tiles_cover_sk_block(uint32_t num_sk_blocks_, + uint32_t iters_per_sk_block_) const + { + return k_iters_per_tile.div(num_sk_blocks_ * iters_per_sk_block_ + k_iters_per_tile.get() - + 1); + } + + __host__ __device__ uint32_t get_total_acc_buffers() const + { + uint32_t tiles_cover_big_blocks = + get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block); + uint32_t tiles_cover_little_blocks = + get_tiles_cover_sk_block(sk_num_blocks - sk_num_big_blocks, k_iters_per_big_block - 1); + + uint32_t total_intersec_big = + get_tile_intersections(tiles_cover_big_blocks, equiv_tiles_big); + uint32_t total_intersec_little = + get_tile_intersections(tiles_cover_little_blocks, equiv_tiles_little); + + return sk_num_blocks + total_intersec_big + total_intersec_little; + } + + __device__ uint32_t get_acc_buffer_offset_from_tile(uint32_t tile_idx_) const + { + // TODO: from big to little + uint32_t tiles_cover_big_blocks = + get_tiles_cover_sk_block(sk_num_big_blocks, k_iters_per_big_block); + if(tile_idx_ < tiles_cover_big_blocks) + { + uint32_t touched_sk_blocks = + (tile_idx_ * k_iters_per_tile.get() + k_iters_per_big_block - 1) / + k_iters_per_big_block; + uint32_t current_intersec = get_tile_intersections(tile_idx_, equiv_tiles_big); + return touched_sk_blocks + current_intersec; + } + else + { + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + uint32_t tile_idx_little_reverse = get_sk_tiles() - tile_idx_; + uint32_t touched_sk_blocks = + (tile_idx_little_reverse * k_iters_per_tile.get() + iters_per_little_sk_block - 1) / + iters_per_little_sk_block; + uint32_t current_intersec = + get_tile_intersections(tile_idx_little_reverse, equiv_tiles_little); + return get_total_acc_buffers() - (touched_sk_blocks + current_intersec); + } + } + + __device__ uint32_t get_acc_buffer_offset_from_block(uint32_t block_idx_) const + { + uint32_t iters_per_big_sk_block = k_iters_per_big_block; + uint32_t iters_per_little_sk_block = k_iters_per_big_block - 1; + if(block_idx_ < sk_num_big_blocks) + { + uint32_t touched_tiles = k_iters_per_tile.div(block_idx_ * iters_per_big_sk_block + + k_iters_per_tile.get() - 1); + uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_big); + return block_idx_ + current_intersec; + } + else + { + uint32_t block_idx_little_reverse = sk_num_blocks - block_idx_; + uint32_t touched_tiles = k_iters_per_tile.div( + block_idx_little_reverse * iters_per_little_sk_block + k_iters_per_tile.get() - 1); + uint32_t current_intersec = get_tile_intersections(touched_tiles, equiv_tiles_little); + return get_total_acc_buffers() - (block_idx_little_reverse + current_intersec); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp new file mode 100755 index 000000000000..02dba97430d7 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp @@ -0,0 +1,1130 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" + +namespace ck { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// output : F[M, N0], where N0 is number of blocks along N dimension +// output : G[M, N0], where N0 is number of blocks along N dimension +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// F, G = welford(E) +// Assume: +// D0, D1, ... and E have the same layout +// Calculate mean & variance along N dimension for E +template +struct GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr 
auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(ABDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + // A desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // B desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const EGridDescriptor_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDescriptor_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + template + __host__ __device__ static constexpr auto + MakeMeanVarCountGridDescriptor_MBlock_MPerBlock_NBlock(const GridDescriptor_M_N& grid_desc_m_n) + { + const auto M = grid_desc_m_n.GetLength(I0); + const auto NBlock = grid_desc_m_n.GetLength(I1); + const auto MBlock = M / MPerBlock; + + const auto grid_desc_mblock_mperblock_nblock = 
transform_tensor_descriptor( + grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_pass_through_transform(NBlock)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{})); + + return grid_desc_mblock_mperblock_nblock; + } + + // return block_id to E matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, + const BGridDesc_N_K& b_grid_desc_n_k, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + // check consistency of desc + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1))) + { + return false; + } + + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // check block-to-E-tile + if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && + b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EMeanVarDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DefaultAGridDesc_AK0_M_AK1 = + remove_cvref_t; + using DefaultBGridDesc_BK0_N_BK1 = + remove_cvref_t; + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using MeanVarGridDescriptor_MBlock_MPerBlock_NBlock = + remove_cvref_t; + using CountGridDescriptor_MBlock_MPerBlock_NBlock = + remove_cvref_t; + using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + using DefaultBlock2ETileMap = + remove_cvref_t; + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __device__ static void + Run(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EMeanVarDataType* __restrict__ p_e_grid, + EMeanVarDataType* __restrict__ p_welford_mean_grid, + EMeanVarDataType* __restrict__ p_welford_var_grid, + int32_t* __restrict__ p_welford_count, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const 
BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const MeanVarGridDescriptor_MBlock_MPerBlock_NBlock& + mean_var_grid_desc_mblock_mperblock_nblock, + const CountGridDescriptor_MBlock_MPerBlock_NBlock& count_grid_desc_mblock_mperblock_nblock, + const Block2ETileMap& block_2_etile_map, + index_t NRaw) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + auto mean_grid_buf = make_dynamic_buffer( + p_welford_mean_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize()); + + auto var_grid_buf = make_dynamic_buffer( + p_welford_var_grid, mean_var_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize()); + + auto welford_count_grid_buf = make_dynamic_buffer( + p_welford_count, count_grid_desc_mblock_mperblock_nblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + 
ABDataType, + ABDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ABDataType, + ABDataType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C, Welford and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + 
SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence, + false>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_der_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, + false>{}; + + // LDS c_shuffle_block_desc_mperblock_nperblock + constexpr auto c_shuffle_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert(PostShuffleThreadClusterSize_M_N::At(I0) * + PostShuffleThreadClusterSize_M_N::At(I1) == + BlockSize, + "wrong!"); + + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % + PostShuffleThreadClusterSize_M_N::At(I0) == + 0 && + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) % + PostShuffleThreadClusterSize_M_N::At(I1) == + 0, + "wrong!"); + + constexpr index_t PostShuffleThreadSliceSize_M = + (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) / + PostShuffleThreadClusterSize_M_N::At(I0); + + constexpr index_t PostShuffleThreadSliceSize_N = + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / + PostShuffleThreadClusterSize_M_N::At(I1); + + constexpr auto PostShuffleThreadSliceSize_M_N = + Sequence{}; + + // VGPR post_shuffle_thread_desc_m_n + constexpr auto post_shuffle_thread_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{})); + + auto e_thread_buf = make_static_buffer( + post_shuffle_thread_desc_m_n.GetElementSpaceSize()); + + // To apply D0, D1, ... and Welford. + // threadwise copy from LDS to VGPR + constexpr auto post_shuffle_thread_cluster_desc = + make_cluster_descriptor(PostShuffleThreadClusterSize_M_N{}, Sequence<0, 1>{}); + + const auto post_shuffle_thread_cluster_idx = + post_shuffle_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto post_shuffle_thread_data_idx_begin = + post_shuffle_thread_cluster_idx * PostShuffleThreadSliceSize_M_N; + + // To apply D0, D1, ... and Welford. 
+ // Copy c shuffle from LDS back to VGPR + auto post_shuffle_thread_copy_lds_to_vgpr = + ThreadwiseTensorSliceTransfer_v2, + 1, + PostShuffleScalarPerVector, + 1, + true>{c_shuffle_block_desc_mperblock_nperblock, + post_shuffle_thread_data_idx_begin}; + + // D0, D1, ..., Dn + constexpr auto post_shuffle_thread_desc_I1_mperblock_I1_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + // FIXME: Decrease usage of VGPR + // Apply pointwise lambda function from multi-source (Global and LDS) into VGPR + auto ds_thread_buf = generate_tuple( + [&](auto) { + return make_static_buffer( + post_shuffle_thread_desc_I1_mperblock_I1_nperblock.GetElementSpaceSize()); + }, + Number{}); + + // Copy D0, D1, ..., Dn from global to VGPR + auto ds_thread_copy_global_to_vgpr = generate_tuple( + [&](auto I) { + using DDataType = remove_cvref_t>; + return ThreadwiseTensorSliceTransfer_v2< + DDataType, + AccDataType, + decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]), + decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + PostShuffleScalarPerVector, + 1, + true>( + ds_grid_desc_mblock_mperblock_nblock_nperblock[I], + make_multi_index( + I0, + m_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I1])); + }, + Number{}); + + auto e_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + EMeanVarDataType, + decltype(post_shuffle_thread_desc_I1_mperblock_I1_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + tensor_operation::element_wise::PassThrough, + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + 3, // DstVectorDim + PostShuffleScalarPerVector, + InMemoryDataOperationEnum::Set, + 1, + true>{ + e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + post_shuffle_thread_data_idx_begin[I1]), + tensor_operation::element_wise::PassThrough{}}; + + // Welford + constexpr auto thread_welford_src_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{})); + + constexpr auto thread_welford_dst_desc_m = make_naive_tensor_descriptor_packed( + make_tuple(Number{})); + + using ThreadwiseWelford = ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford, + false>; + + constexpr int num_shuffleM = + MPerBlock / (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl); + + constexpr int num_shuffleN = + NPerBlock / (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl); + + using mean_var_vgpr_type = + decltype(make_static_buffer( + thread_welford_dst_desc_m.GetElementSpaceSize())); + + using welford_count_vgpr_type = + decltype(make_static_buffer( + thread_welford_dst_desc_m.GetElementSpaceSize())); + + Array threadwise_welfords; + Array mean_thread_bufs; + Array var_thread_bufs; + Array welford_count_thread_bufs; + + int max_count = PostShuffleThreadSliceSize_N * num_shuffleN; + const auto nblock = mean_var_grid_desc_mblock_mperblock_nblock.GetLength(I2); + + // tail block + if(block_work_idx[I1] % nblock == nblock - 1) + { + constexpr index_t NPerShuffleBlock = + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl; + + int NPerBlockTail = NRaw - NPerBlock * (nblock - 1); + int thread_max_len = + PostShuffleThreadSliceSize_N * (post_shuffle_thread_cluster_idx[I1] + 1); + int shuffle_step = 0; + while(thread_max_len <= NPerBlockTail 
&& shuffle_step < num_shuffleN) + { + ++shuffle_step; + thread_max_len += NPerShuffleBlock; + } + + int delta = 0; + if(thread_max_len - NPerBlockTail > PostShuffleThreadSliceSize_N) + delta = 0; + else if(NPerBlockTail > thread_max_len) + delta = PostShuffleThreadSliceSize_N; + else + delta = PostShuffleThreadSliceSize_N - thread_max_len + NPerBlockTail; + + max_count = shuffle_step * PostShuffleThreadSliceSize_N + delta; + } + + static_for<0, num_shuffleM, 1>{}([&](auto i) { + threadwise_welfords(i).max_count_ = max_count; + mean_thread_bufs(i) = make_static_buffer( + thread_welford_dst_desc_m.GetElementSpaceSize()); + + var_thread_bufs(i) = make_static_buffer( + thread_welford_dst_desc_m.GetElementSpaceSize()); + + welford_count_thread_bufs(i) = make_static_buffer( + thread_welford_dst_desc_m.GetElementSpaceSize()); + + static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) { + mean_thread_bufs(i)(j) = type_convert(0.0f); + var_thread_bufs(i)(j) = type_convert(0.0f); + welford_count_thread_bufs(i)(j) = 0; + }); + }); + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_der_global.GetNumOfAccess(), "wrong!"); + + int shuffleM_index = __builtin_amdgcn_readfirstlane(0); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to read from LDS + block_sync_lds(); + + // each thread shuffle data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to write to LDS + block_sync_lds(); + + // Get shuffle data from LDS to VGPR + post_shuffle_thread_copy_lds_to_vgpr.Run(c_shuffle_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + post_shuffle_thread_desc_m_n, + make_tuple(I0, I0), + e_thread_buf); + + // Global read D0, D1, ... + static_for<0, NumDTensor, 1>{}([&](auto Id) { + auto& d_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(Id); + d_thread_copy_global_to_vgpr.Run( + ds_grid_desc_mblock_mperblock_nblock_nperblock[Id], + ds_grid_buf[Id], + post_shuffle_thread_desc_I1_mperblock_I1_nperblock, + make_tuple(I0, I0, I0, I0), + ds_thread_buf(Id)); + + if constexpr(access_id < num_access - 1) + { + // move on D0, D1, ... 
+ constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id); + d_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + ds_grid_desc_mblock_mperblock_nblock_nperblock[Id], de_global_step); + } + }); + + // cde_element_op(e, c, d0, d1, ...); + static_for<0, post_shuffle_thread_desc_m_n.GetElementSize(), 1>{}([&](auto i) { + const auto c_ds_src_data_refs = concat_tuple_of_reference( + tie(e_thread_buf[i]), + generate_tie( + [&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; }, + Number{})); + auto e_dst_data_refs = tie(e_thread_buf(i)); + unpack2(cde_element_op, e_dst_data_refs, c_ds_src_data_refs); + }); + + // Global write E + e_thread_copy_vgpr_to_global.Run(post_shuffle_thread_desc_I1_mperblock_I1_nperblock, + make_tuple(I0, I0, I0, I0), + e_thread_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + if constexpr(access_id < num_access - 1) + { + // move on E + constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id); + e_thread_copy_vgpr_to_global.MoveDstSliceWindow( + e_grid_desc_mblock_mperblock_nblock_nperblock, de_global_step); + } + + // Threadwise welford + auto& threadwise_welford = threadwise_welfords(shuffleM_index); + auto& mean_thread_buf = mean_thread_bufs(shuffleM_index); + auto& var_thread_buf = var_thread_bufs(shuffleM_index); + + threadwise_welford.Run(e_thread_buf, mean_thread_buf, var_thread_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id); + constexpr int shuffleMInc = + de_global_step[I1] / + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1); + shuffleM_index = __builtin_amdgcn_readfirstlane(shuffleM_index + shuffleMInc); + } + }); // copy c, d, e + welford + + // Blockwise welford and write out + static_for<0, num_shuffleM, 1>{}([&](auto i) { + auto& mean_thread_buf = mean_thread_bufs(i); + auto& var_thread_buf = var_thread_bufs(i); + auto& count_thread_buf = welford_count_thread_bufs(i); + + static_for<0, PostShuffleThreadSliceSize_M, 1>{}([&](auto j) { + block_sync_lds(); + count_thread_buf(j) = threadwise_welfords(i).cur_count_; + BlockwiseWelford::Run( + mean_thread_buf(j), var_thread_buf(j), count_thread_buf(j)); + }); + + if(post_shuffle_thread_cluster_idx[I1] == 0) + { + constexpr auto thread_welford_desc_I_m_I = make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1)); + + constexpr int shuffleMPerBlock = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1); + + auto mean_var_count_thread_copy_index = make_multi_index( + block_work_idx[I0], // mblock + shuffleMPerBlock * i + post_shuffle_thread_data_idx_begin[I0], // mperblock + block_work_idx[I1]); // nblock + + auto mean_var_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + EMeanVarDataType, + decltype(thread_welford_desc_I_m_I), + decltype(mean_var_grid_desc_mblock_mperblock_nblock), + tensor_operation::element_wise::PassThrough, + Sequence<1, PostShuffleThreadSliceSize_M, 1>, + Sequence<0, 1, 2>, + 1, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{mean_var_grid_desc_mblock_mperblock_nblock, + mean_var_count_thread_copy_index, + tensor_operation::element_wise::PassThrough{}}; + + mean_var_thread_copy_vgpr_to_global.Run( + thread_welford_desc_I_m_I, + make_tuple(I0, I0, I0), + mean_thread_buf, + mean_var_grid_desc_mblock_mperblock_nblock, + mean_grid_buf); // write mean + + mean_var_thread_copy_vgpr_to_global.Run( + thread_welford_desc_I_m_I, + make_tuple(I0, I0, I0), + var_thread_buf, + 
mean_var_grid_desc_mblock_mperblock_nblock, + var_grid_buf); // write variance + + // Stride of count is [0, 1]. Only the first row in count[0, 0:nblock] need + // to be written. + if(i == 0 && block_work_idx[I0] == 0 && + post_shuffle_thread_cluster_idx[I0] == 0) + { + auto count_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + int32_t, + int32_t, + decltype(thread_welford_desc_I_m_I), + decltype(count_grid_desc_mblock_mperblock_nblock), + tensor_operation::element_wise::PassThrough, + Sequence<1, PostShuffleThreadSliceSize_M, 1>, + Sequence<0, 1, 2>, + 1, + 1, + InMemoryDataOperationEnum::Set, + 1, + false>{count_grid_desc_mblock_mperblock_nblock, + mean_var_count_thread_copy_index, + tensor_operation::element_wise::PassThrough{}}; + + count_thread_copy_vgpr_to_global.Run( + thread_welford_desc_I_m_I, + make_tuple(I0, I0, I0), + count_thread_buf, + count_grid_desc_mblock_mperblock_nblock, + welford_count_grid_buf); // write count + } + } + }); + + } // shuffle C + Ds + welford + write out + } // run +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp new file mode 100755 index 000000000000..69468c25befe --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp @@ -0,0 +1,394 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" + +namespace ck { + +template +struct GridwiseWelfordSecondHalfLayernorm2d +{ + static_assert(NThreadSliceSize % ESrcVectorSize == 0 && + NThreadSliceSize % GammaSrcVectorSize == 0 && + NThreadSliceSize % BetaSrcVectorSize == 0, + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert(NThreadSliceSize % HDstVectorSize == 0, + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + using ThreadClusterLengths_M_N = Sequence; + using ThreadBufferDimAccessOrder = Sequence<0, 1>; + using ThreadClusterArrangeOrder = Sequence<0, 1>; + + static constexpr auto thread_cluster_desc_m_n = + make_cluster_descriptor(ThreadClusterLengths_M_N{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_N = Sequence; + static constexpr auto thread_buffer_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + using ThreadBufferLengths_M_1 = Sequence; + static constexpr auto thread_buffer_desc_m_1 = + 
make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number<1>{})); + + using ThreadBufferLengths_N = Sequence; + static constexpr auto thread_buffer_desc_n = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using ThreadWelfordSrcDesc_M_1 = decltype(thread_buffer_desc_m_1); + using ThreadWelfordDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelfordMerge; + + using BlockwiseWelford = BlockwiseWelford; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize; + + __device__ static void Run(const EMeanVarDataType* __restrict__ p_e_grid, + const EMeanVarDataType* __restrict__ p_in_welford_mean_grid, + const EMeanVarDataType* __restrict__ p_in_welford_var_grid, + const int32_t* __restrict__ p_in_welford_count_grid, + const GammaDataType* __restrict__ p_gamma_grid, + const BetaDataType* __restrict__ p_beta_grid, + HDataType* __restrict__ p_h_grid, + const EHGridDesc_M_N& e_grid_desc_m_n, + const EHGridDesc_M_N& h_grid_desc_m_n, + const MeanVarGridDesc_M_NBlock& mean_var_grid_desc_m_nblock, + const CountGridDesc_M_NBlock& count_grid_desc_m_nblock, + const GammaBetaGridDesc_N& gamma_grid_desc_n, + const GammaBetaGridDesc_N& beta_grid_desc_n, + index_t numMeanVarCountBlockTileIteration_N, + index_t NBlockClusterLength, + ComputeDataType epsilon, + HElementwiseOperation h_element_op) + { + // Thread/Block id + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const auto block_work_idx = make_tuple(block_global_id / NBlockClusterLength, + block_global_id % NBlockClusterLength); + + const auto thread_cluster_idx = + thread_cluster_desc_m_n.CalculateBottomIndex(make_multi_index(thread_local_id)); + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_n_cluster_id = thread_cluster_idx[I1]; + + // Global Memory + const auto e_global_val_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_m_n.GetElementSpaceSize()); + + const auto welford_mean_global_val_buf = make_dynamic_buffer( + p_in_welford_mean_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize()); + + const auto welford_var_global_val_buf = make_dynamic_buffer( + p_in_welford_var_grid, mean_var_grid_desc_m_nblock.GetElementSpaceSize()); + + const auto welford_count_global_val_buf = make_dynamic_buffer( + p_in_welford_count_grid, count_grid_desc_m_nblock.GetElementSpaceSize()); + + const auto gamma_global_val_buf = make_dynamic_buffer( + p_gamma_grid, gamma_grid_desc_n.GetElementSpaceSize()); + + const auto beta_global_val_buf = make_dynamic_buffer( + p_beta_grid, beta_grid_desc_n.GetElementSpaceSize()); + + auto h_global_val_buf = make_dynamic_buffer( + p_h_grid, h_grid_desc_m_n.GetElementSpaceSize()); + + // VGPR + StaticBuffer + in_welford_mean_thread_buf; + StaticBuffer + in_welford_var_thread_buf; + StaticBuffer + in_welford_count_thread_buf; + + StaticBuffer + welford_mean_thread_buf; + StaticBuffer + welford_var_thread_buf; + StaticBuffer + welford_count_thread_buf; + + StaticBuffer + e_thread_buf; + StaticBuffer + gamma_thread_buf; + StaticBuffer + beta_thread_buf; + StaticBuffer + h_thread_buf; + + // IO + auto threadwise_mean_load_m_nblock = + ThreadwiseTensorSliceTransfer_v2( + mean_var_grid_desc_m_nblock, + make_multi_index(block_work_idx[I0] * M_BlockTileSize 
+ + thread_m_cluster_id * MThreadSliceSize, + thread_n_cluster_id)); + + auto threadwise_var_load_m_nblock = + ThreadwiseTensorSliceTransfer_v2( + mean_var_grid_desc_m_nblock, + make_multi_index(block_work_idx[I0] * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_n_cluster_id)); + + auto threadwise_count_load_m_nblock = + ThreadwiseTensorSliceTransfer_v2( + count_grid_desc_m_nblock, + make_multi_index(block_work_idx[I0] * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_n_cluster_id)); + + auto threadwise_e_load_m_n = + ThreadwiseTensorSliceTransfer_v2( + e_grid_desc_m_n, + make_multi_index( + block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize)); + + auto threadwise_gamma_load_n = + ThreadwiseTensorSliceTransfer_v2, // DimAccessOrder, + 0, // SrcVectorDim, + GammaSrcVectorSize, + 1, + true>( + gamma_grid_desc_n, + make_multi_index(block_work_idx[I1] * N_BlockTileSize + + thread_n_cluster_id * NThreadSliceSize)); + + auto threadwise_beta_load_n = + ThreadwiseTensorSliceTransfer_v2, // DimAccessOrder, + 0, // SrcVectorDim, + BetaSrcVectorSize, + 1, + true>( + beta_grid_desc_n, + make_multi_index(block_work_idx[I1] * N_BlockTileSize + + thread_n_cluster_id * NThreadSliceSize)); + + auto threadwise_h_store_m_n = + ThreadwiseTensorSliceTransfer_v1r3( + h_grid_desc_m_n, + make_multi_index( + block_work_idx[I0] * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_work_idx[I1] * N_BlockTileSize + thread_n_cluster_id * NThreadSliceSize), + h_element_op); + + // step1: Merge mean and variance + constexpr auto mean_var_count_thread_copy_step_I0_n = + make_multi_index(I0, NThreadClusterSize); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + welford_mean_thread_buf(I) = type_convert(0.0f); + welford_var_thread_buf(I) = type_convert(0.0f); + welford_count_thread_buf(I) = 0; + }); + + for(index_t n = 0; n < numMeanVarCountBlockTileIteration_N; ++n) + { + threadwise_mean_load_m_nblock.Run(mean_var_grid_desc_m_nblock, + welford_mean_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_mean_thread_buf); + + threadwise_var_load_m_nblock.Run(mean_var_grid_desc_m_nblock, + welford_var_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_var_thread_buf); + + threadwise_count_load_m_nblock.Run(count_grid_desc_m_nblock, + welford_count_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_count_thread_buf); + + ThreadwiseWelford::Run(in_welford_mean_thread_buf, + in_welford_var_thread_buf, + in_welford_count_thread_buf, + welford_mean_thread_buf, + welford_var_thread_buf, + welford_count_thread_buf); + + threadwise_mean_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock, + mean_var_count_thread_copy_step_I0_n); + threadwise_var_load_m_nblock.MoveSrcSliceWindow(mean_var_grid_desc_m_nblock, + mean_var_count_thread_copy_step_I0_n); + threadwise_count_load_m_nblock.MoveSrcSliceWindow(count_grid_desc_m_nblock, + mean_var_count_thread_copy_step_I0_n); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseWelford::Run( + welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I)); + }); + + // step2: normalization + // h[m, n] = [(e[m, n] - mean[m]) / sqrt(var[m] + eps)] * gamma[n] + beta[n] + threadwise_e_load_m_n.Run(e_grid_desc_m_n, + e_global_val_buf, + thread_buffer_desc_m_n, + 
make_tuple(I0, I0), + e_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto m) { + auto divisor = 1 / ck::math::sqrt(welford_var_thread_buf(m) + epsilon); + static_for<0, NThreadSliceSize, 1>{}([&](auto n) { + constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n)); + h_thread_buf(Number{}) = + (e_thread_buf(Number{}) - welford_mean_thread_buf(m)) * divisor; + }); + }); + + threadwise_gamma_load_n.Run(gamma_grid_desc_n, + gamma_global_val_buf, + thread_buffer_desc_n, + make_tuple(I0), + gamma_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto m) { + static_for<0, NThreadSliceSize, 1>{}([&](auto n) { + constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n)); + h_thread_buf(Number{}) = h_thread_buf(Number{}) * gamma_thread_buf(n); + }); + }); + + threadwise_beta_load_n.Run(beta_grid_desc_n, + beta_global_val_buf, + thread_buffer_desc_n, + make_tuple(I0), + beta_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto m) { + static_for<0, NThreadSliceSize, 1>{}([&](auto n) { + constexpr auto m_n = thread_buffer_desc_m_n.CalculateOffset(make_tuple(m, n)); + h_thread_buf(Number{}) = h_thread_buf(Number{}) + beta_thread_buf(n); + }); + }); + + threadwise_h_store_m_n.Run(thread_buffer_desc_m_n, + make_tuple(I0, I0), + h_thread_buf, + h_grid_desc_m_n, + h_global_val_buf); + + } // run +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp new file mode 100755 index 000000000000..2d9197a7f443 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_multiblock.hpp @@ -0,0 +1,321 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
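//
// Shape of the multi-block, multi-output reduction implemented in this
// header, as a scalar sketch (illustrative only; the real kernel tiles M and
// K across a thread cluster, stages partials through LDS, and handles
// NumReduction outputs with per-output alpha/beta values):
//
//   // out[m] = alpha * acc_op( reduce_k( in_op(in[m, k]) ) ) + beta * out_prev[m]
//   AccDataType acc = identity;                 // ReduceOperation's identity value
//   for(index_t k = k_begin; k < k_end; ++k)    // this block group's K slice
//       acc = reduce(acc, in_op(in[m][k]));
//   acc_op(acc, acc);
//   acc *= alpha;
//   if(beta != 0)
//       acc += beta * out_prev[m];
//   // the result is then committed according to OutMemoryDataOperation
//   // (plain store, or e.g. an atomic add when several block groups
//   // contribute to the same out[m])
//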
+ +#pragma once + +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +kernel_multiple_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k, + const OutGridDesc_M_Tuple out_grid_desc_m_tuple, + const InElementwiseOperationTuple in_elementwise_op_tuple, + const AccElementwiseOperationTuple acc_elementwise_op_tuple, + index_t block_group_size, + index_t num_k_block_tile_iteration, + Array alpha_values, + const InDataType* const __restrict__ p_in_value_global, + Array beta_values, + OutDataTypePointerTuple p_out_value_global_tuple) +{ + GridwiseMultipleReduction::Run(in_grid_desc_m_k, + out_grid_desc_m_tuple, + in_elementwise_op_tuple, + acc_elementwise_op_tuple, + block_group_size, + num_k_block_tile_iteration, + alpha_values, + p_in_value_global, + beta_values, + p_out_value_global_tuple); +}; + +template +struct GridwiseMultipleReduction_mk_to_m_multiblock +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert(NumReduction == OutDataTypePointerTuple::Size() && + NumReduction == OutGridDesc_M_Tuple::Size() && + NumReduction == OutDstVectorSizeSeq::Size() && + NumReduction == InElementwiseOperationTuple::Size() && + NumReduction == AccElementwiseOperationTuple::Size(), + "All tuple should have the same size as the number of Reductions!"); + + static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using BlockwiseReduce = PartitionedBlockwiseReduction; + + using ThreadwiseReduce = ThreadwiseReduction; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + using Accumulation = detail::AccumulateWithNanCheck; + + __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M_Tuple& out_grid_desc_m_tuple, + const InElementwiseOperationTuple& in_elementwise_op_tuple, + const AccElementwiseOperationTuple& acc_elementwise_op_tuple, + index_t block_group_size, + index_t num_k_block_tile_iteration, + Array alpha_values, + const InDataType* const __restrict__ p_in_value_global, + Array beta_values, + OutDataTypePointerTuple p_out_value_global_tuple) + { + const auto identityVal = 
ReduceOperation::template GetIdentityValue(); + + // LDS, reused by all reductions + __shared__ AccDataType p_reduce_work_buffer[BlockSize]; + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + auto out_global_val_buf_tuple = generate_tuple( + [&](auto iR) { + return make_dynamic_buffer( + p_out_value_global_tuple[iR], out_grid_desc_m_tuple[iR].GetElementSpaceSize()); + }, + Number{}); + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + StaticBuffer + in_thread_buf; + + auto in_thread_buf_tuple = generate_tuple( + [&](auto iR) { + (void)iR; + return StaticBuffer{}; + }, + Number{}); + + auto accu_value_buf_tuple = generate_tuple( + [&](auto iR) { + (void)iR; + return StaticBuffer{}; + }, + Number{}); + + static_for<0, NumReduction, 1>{}([&](auto iR) { + static_for<0, MThreadSliceSize, 1>{}( + [&](auto J) { accu_value_buf_tuple(iR)(J) = identityVal; }); + }); + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / block_group_size; + const index_t block_local_id = block_global_id % block_group_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize); + + index_t reducedTiles = 0; + do + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + static_for<0, NumReduction, 1>{}([&](auto iR) { + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number{}), + in_thread_buf(Number{})); + }); + }); + + ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR)); + }); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + static_for<0, NumReduction, 1>{}([&](auto iR) { + using OutDataTypePointer = remove_cvref_t; + using OutDataType = remove_cvref_t>; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf_tuple(iR)(I)); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if(thread_k_cluster_id == 0) + { + acc_elementwise_op_tuple[iR](accu_value_buf_tuple(iR)(I), + accu_value_buf_tuple(iR)(I)); + + accu_value_buf_tuple(iR)(I) *= alpha_values[iR]; + } + }); + + if(thread_k_cluster_id == 0) + { + if(!float_equal_zero{}(beta_values[iR])) + { + 
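                    // beta path: re-read the iR-th output and blend it into the
                    // accumulator, which has already been scaled by alpha_values[iR]
                    // above, i.e. acc[m] += beta_values[iR] * prior_out[m];
                    // the float_equal_zero guard skips this extra global read
                    // in the common beta == 0 case.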
StaticBuffer + priorDstValueBuf; + + auto threadwise_dst_load = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + OutDstVectorSizeSeq::At(iR), + 1, + false>( + out_grid_desc_m_tuple[iR], + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_dst_load.Run(out_grid_desc_m_tuple[iR], + out_global_val_buf_tuple(iR), + reduced_data_desc, + make_tuple(I0), + priorDstValueBuf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf_tuple(iR)(I) += + type_convert(priorDstValueBuf[I]) * beta_values[iR]; + }); + }; + + auto threadwise_dst_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSizeSeq::At(iR), + OutMemoryDataOperation, + 1, + true>( + out_grid_desc_m_tuple[iR], + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_store.Run(reduced_data_desc, + make_tuple(I0), + accu_value_buf_tuple[iR], + out_grid_desc_m_tuple[iR], + out_global_val_buf_tuple(iR)); + }; + }); + }; +}; // namespace ck + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp new file mode 100755 index 000000000000..fc4f27e33b15 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp @@ -0,0 +1,264 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +kernel_multiple_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, + const OutGridDesc_M_Tuple out_grid_desc_m_tuple, + const InElementwiseOperationTuple in_elementwise_op_tuple, + const AccElementwiseOperationTuple acc_elementwise_op_tuple, + Array alpha_values, + const InDataType* const __restrict__ p_in_value_global, + Array beta_values, + OutDataTypePointerTuple p_out_value_global_tuple) +{ + GridwiseMultipleReduction::Run(in_grid_desc_m_k, + out_grid_desc_m_tuple, + in_elementwise_op_tuple, + acc_elementwise_op_tuple, + alpha_values, + p_in_value_global, + beta_values, + p_out_value_global_tuple); +}; + +template +struct GridwiseMultipleReduction_mk_to_m_threadwise +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert(NumReduction == OutDataTypePointerTuple::Size() && + NumReduction == OutGridDesc_M_Tuple::Size() && + NumReduction == OutDstVectorSizeSeq::Size() && + NumReduction == InElementwiseOperationTuple::Size() && + NumReduction == AccElementwiseOperationTuple::Size(), + "All tuple should have the same size as the number of Reductions!"); + + static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); + + using ThreadBufferDimAccessOrder = 
+ typename conditional, Sequence<0, 1>>::type; + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseReduce = ThreadwiseReduction; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + + using Accumulation = detail::AccumulateWithNanCheck; + + __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M_Tuple& out_grid_desc_m_tuple, + const InElementwiseOperationTuple& in_elementwise_op_tuple, + const AccElementwiseOperationTuple& acc_elementwise_op_tuple, + Array alpha_values, + const InDataType* const __restrict__ p_in_value_global, + Array beta_values, + OutDataTypePointerTuple p_out_value_global_tuple) + { + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + auto out_global_val_buf_tuple = generate_tuple( + [&](auto iR) { + return make_dynamic_buffer( + p_out_value_global_tuple[iR], out_grid_desc_m_tuple[iR].GetElementSpaceSize()); + }, + Number{}); + + StaticBuffer + in_thread_buf; + + auto in_thread_buf_tuple = generate_tuple( + [&](auto iR) { + (void)iR; + return StaticBuffer{}; + }, + Number{}); + + auto accu_value_buf_tuple = generate_tuple( + [&](auto iR) { + (void)iR; + return StaticBuffer{}; + }, + Number{}); + + static_for<0, NumReduction, 1>{}([&](auto iR) { + static_for<0, MThreadSliceSize, 1>{}( + [&](auto J) { accu_value_buf_tuple(iR)(J) = identityVal; }); + }); + + const index_t thread_global_1d_id = get_thread_global_1d_id(); + + const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); + + constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); + + index_t reducedLength = 0; + do + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + static_for<0, NumReduction, 1>{}([&](auto iR) { + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(Number{}), + in_thread_buf(Number{})); + }); + }); + + ThreadwiseReduce::Reduce(in_thread_buf_tuple(iR), accu_value_buf_tuple(iR)); + }); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedLength += KThreadSliceSize; + } while(reducedLength < toReduceLength); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + static_for<0, NumReduction, 1>{}([&](auto iR) { + using OutDataTypePointer = remove_cvref_t; + using OutDataType = remove_cvref_t>; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + acc_elementwise_op_tuple[iR](accu_value_buf_tuple(iR)(I), + accu_value_buf_tuple(iR)(I)); + + accu_value_buf_tuple(iR)(I) *= alpha_values[iR]; + }); + + if(!float_equal_zero{}(beta_values[iR])) + { 
+ StaticBuffer + priorDstValueBuf; + + auto threadwise_dst_load = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + OutDstVectorSizeSeq::At(iR), + 1, + false>( + out_grid_desc_m_tuple[iR], + make_multi_index(thread_global_1d_id * MThreadSliceSize)); + + threadwise_dst_load.Run(out_grid_desc_m_tuple[iR], + out_global_val_buf_tuple(iR), + reduced_data_desc, + make_tuple(I0), + priorDstValueBuf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf_tuple(iR)(I) += + type_convert(priorDstValueBuf[I]) * beta_values[iR]; + }); + }; + + auto threadwise_dst_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSizeSeq::At(iR), + OutMemoryDataOperation, + 1, + true>( + out_grid_desc_m_tuple[iR], + make_multi_index(thread_global_1d_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_store.Run(reduced_data_desc, + make_tuple(I0), + accu_value_buf_tuple[iR], + out_grid_desc_m_tuple[iR], + out_global_val_buf_tuple(iR)); + }); + }; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp new file mode 100755 index 000000000000..774df1f99348 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp @@ -0,0 +1,613 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k, + const OutGridDesc_M out_grid_desc_m, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op, + index_t block_group_size, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) +{ + if constexpr(!OutputIndex) + { + (void)p_in_index_global; + (void)p_out_index_global; + + GridwiseReduction::Run(in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + block_group_size, + num_k_block_tile_iteration, + alpha, + p_in_value_global, + beta, + p_out_value_global); + } + else + { + GridwiseReduction::template RunWithIndex(in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + num_k_block_tile_iteration, + alpha, + p_in_value_global, + p_in_index_global, + beta, + p_out_value_global, + p_out_index_global); + }; +}; + +template +struct GridwiseReduction_mk_to_m_multiblock +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool 
reorder_thread_cluster = (InSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using BlockwiseReduce = PartitionedBlockwiseReduction; + + using ThreadwiseReduce = ThreadwiseReduction; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + using Accumulation = detail::AccumulateWithNanCheck; + + __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation& in_elementwise_op, + const AccElementwiseOperation& acc_elementwise_op, + index_t block_group_size, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global) + { + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + // LDS + __shared__ AccDataType p_reduce_work_buffer[BlockSize]; + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + auto out_global_val_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + StaticBuffer + in_thread_buf; + + StaticBuffer accu_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; }); + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = block_global_id / block_group_size; + const index_t block_local_id = block_global_id % block_group_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize); + + index_t reducedTiles = 0; + do + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + 
static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op(in_thread_buf(Number{}), + in_thread_buf(Number{})); + }); + }); + + ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + static_for<0, MThreadSliceSize, 1>{}( + [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if(thread_k_cluster_id == 0) + { + acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); + + accu_value_buf(I) *= alpha; + } + }); + + if(thread_k_cluster_id == 0) + { + if(!float_equal_zero{}(beta)) + { + StaticBuffer + priorDstValueBuf; + + auto threadwise_dst_load = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + OutDstVectorSize, + 1, + false>( + out_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_dst_load.Run(out_grid_desc_m, + out_global_val_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValueBuf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValueBuf[I]) * beta; + }); + }; + + auto threadwise_dst_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + true>( + out_grid_desc_m, + make_multi_index(blkgroup_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_store.Run(reduced_data_desc, + make_tuple(I0), + accu_value_buf, + out_grid_desc_m, + out_global_val_buf); + } + }; + + template + __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) + { + using BlockwiseReduceWithIndex = + PartitionedBlockwiseReductionWithIndex, + ThreadClusterArrangeOrder, + ReduceOperation, + PropagateNan>; + + using AccumulationWithIndex = detail::AccumulateWithIndexAndNanCheck; + + (void)in_elementwise_op; + + // LDS + __shared__ AccDataType p_reduce_work_val_buffer[BlockSize]; + __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize]; + + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + const auto in_global_idx_buf = make_dynamic_buffer( + p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize()); + auto out_global_val_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + auto out_global_idx_buf = make_dynamic_buffer( + p_out_index_global, out_grid_desc_m.GetElementSpaceSize()); + + auto reduce_work_val_buf = + make_dynamic_buffer(p_reduce_work_val_buffer, BlockSize); + auto reduce_work_idx_buf = + make_dynamic_buffer(p_reduce_work_idx_buffer, BlockSize); + + StaticBuffer + 
in_thread_val_buf; + + StaticBuffer + in_thread_idx_buf; + + StaticBuffer accu_value_buf; + StaticBuffer accu_index_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_1d_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto threadwise_src_val_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) = identityVal; + accu_index_buf(I) = 0; + }); + + constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize); + + index_t reducedTiles = 0; + + if constexpr(HaveIndexInput) + { + auto threadwise_src_idx_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + do + { + // load the thread slice + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + threadwise_src_idx_load.Run(in_grid_desc_m_k, + in_global_idx_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_idx_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + AccDataType tmpValue = identityVal; + IndexDataType tmpIndex = 0; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + AccumulationWithIndex::Calculate(tmpValue, + in_thread_val_buf[Number{}], + tmpIndex, + in_thread_idx_buf[Number{}]); + }); + + BlockwiseReduceWithIndex::Reduce( + reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex); + + AccumulationWithIndex::Calculate( + accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex); + }); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + } + else + { + index_t indexOffset = 0; + + do + { + // load the thread slice + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + // initialize the indices for the per-thread to-reduce values + in_thread_idx_buf(Number{}) = + indexOffset + thread_k_cluster_id * KThreadSliceSize + iK(); + + // do element-wise pre-reduction operation + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); + }); + + AccDataType tmpValue = identityVal; + IndexDataType tmpIndex = 0; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + AccumulationWithIndex::Calculate(tmpValue, + in_thread_val_buf[Number{}], + tmpIndex, + in_thread_idx_buf[Number{}]); + 
}); + + BlockwiseReduceWithIndex::Reduce( + reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex); + + AccumulationWithIndex::Calculate( + accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex); + }); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + indexOffset += K_BlockTileSize; + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + }; + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if(thread_k_cluster_id == 0) + { + // for indiced operation, acc_elementwise_op shoud do nothing + acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); + + accu_value_buf(I) *= alpha; + } + }); + + if(thread_k_cluster_id == 0) + { + if(!float_equal_zero{}(beta)) + { + StaticBuffer + priorDstValueBuf; + + auto threadwise_dst_load = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + OutDstVectorSize, + 1, + true>( + out_grid_desc_m, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_dst_load.Run(out_grid_desc_m, + out_global_val_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValueBuf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValueBuf[I]) * beta; + }); + }; + + auto threadwise_dst_val_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + out_grid_desc_m, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_dst_idx_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + out_grid_desc_m, + make_multi_index(block_global_1d_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_val_store.Run(reduced_data_desc, + make_tuple(I0), + accu_value_buf, + out_grid_desc_m, + out_global_val_buf); + threadwise_dst_idx_store.Run(reduced_data_desc, + make_tuple(I0), + accu_index_buf, + out_grid_desc_m, + out_global_idx_buf); + } + }; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp new file mode 100755 index 000000000000..910c926c7e47 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp @@ -0,0 +1,488 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
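Both the multiblock kernel above and the threadwise kernel introduced below implement the same M x K -> M reduction contract: an elementwise pre-op on each input element, a templated reduce op along K, an elementwise post-op on the accumulator, and an alpha/beta blend with the prior output. A minimal host-side sketch of that contract, assuming an Add reduction and purely illustrative names (the real kernels tile M and K across threads/blocks and vectorize the loads):

#include <cstddef>
#include <functional>
#include <vector>

// Reference semantics for the mk->m reduction kernels (Add reduction shown;
// CK templates the reduce op, identity value and elementwise operations).
std::vector<float> reduce_mk_to_m(const std::vector<float>& in,              // M x K, row-major
                                  std::size_t M, std::size_t K,
                                  std::vector<float> out,                    // length M, holds prior values
                                  float alpha, float beta,
                                  const std::function<float(float)>& pre_op, // InElementwiseOperation
                                  const std::function<float(float)>& post_op,// AccElementwiseOperation
                                  float identity)                            // ReduceOperation identity (0 for Add)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        float acc = identity;
        for(std::size_t k = 0; k < K; ++k)
            acc += pre_op(in[m * K + k]);          // elementwise pre-op, then reduce along K
        acc = post_op(acc) * alpha;                // elementwise post-op, then alpha scaling
        out[m] = (beta == 0.0f) ? acc : acc + beta * out[m];  // blend with prior output when beta != 0
    }
    return out;
}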
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, + const OutGridDesc_M out_grid_desc_m, + const InElementwiseOperation in_elementwise_op, + const AccElementwiseOperation acc_elementwise_op, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) +{ + if constexpr(!OutputIndex) + { + GridwiseReduction::Run(in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + alpha, + p_in_value_global, + beta, + p_out_value_global); + } + else + { + GridwiseReduction::template RunWithIndex( + in_grid_desc_m_k, + out_grid_desc_m, + in_elementwise_op, + acc_elementwise_op, + alpha, + p_in_value_global, + p_in_index_global, + beta, + p_out_value_global, + p_out_index_global); + }; +}; + +template +struct GridwiseReduction_mk_to_m_threadwise +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + + __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation& in_elementwise_op, + const AccElementwiseOperation& acc_elementwise_op, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global) + { + using ThreadwiseReduce = ThreadwiseReduction; + + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + auto dst_global_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + + StaticBuffer + in_thread_buf; + + StaticBuffer accu_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; }); + + const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); + + auto threadwise_src_val_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, make_multi_index(thread_global_1d_id * 
MThreadSliceSize, 0)); + + constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); + + index_t reducedLength = 0; + do + { + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op(in_thread_buf(Number{}), + in_thread_buf(Number{})); + }); + }); + + ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedLength += KThreadSliceSize; + } while(reducedLength < toReduceLength); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); + + accu_value_buf(I) *= alpha; + }); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + if(!float_equal_zero{}(beta)) + { + auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + 1, + 1, + true>( + out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize)); + + StaticBuffer + priorDstValue_buf; + + threadwise_dst_load.Run(out_grid_desc_m, + dst_global_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValue_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) += type_convert(priorDstValue_buf[I]) * beta; + }); + }; + + auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + false>( + out_grid_desc_m, + make_multi_index(thread_global_1d_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_store.Run( + reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf); + }; + + template + __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation& in_elementwise_op, + const AccElementwiseOperation& acc_elementwise_op, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + const IndexDataType* const __restrict__ p_in_index_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global, + IndexDataType* const __restrict__ p_out_index_global) + { + using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex; + + (void)acc_elementwise_op; + + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + const auto in_global_idx_buf = make_dynamic_buffer( + p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize()); + + auto out_global_val_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + auto out_global_idx_buf = make_dynamic_buffer( + p_out_index_global, out_grid_desc_m.GetElementSpaceSize()); + + StaticBuffer + in_thread_val_buf; + + StaticBuffer + in_thread_idx_buf; + + StaticBuffer accu_value_buf; + StaticBuffer accu_index_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) = identityVal; + accu_index_buf(I) = 0; + }); + + const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = 
make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); + + auto threadwise_src_val_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); + + constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); + + index_t indexStart = 0; + index_t reducedLength = 0; + if constexpr(HaveIndexInput) + { + auto threadwise_src_idx_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); + + do + { + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + + threadwise_src_idx_load.Run(in_grid_desc_m_k, + in_global_idx_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_idx_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); + }); + }); + + ThreadwiseReduceWithIndex::Reduce( + in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + indexStart += KThreadSliceSize; + reducedLength += KThreadSliceSize; + } while(reducedLength < toReduceLength); + } + else + { + do + { + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_val_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + + in_thread_idx_buf(Number{}) = indexStart + iK(); + + in_elementwise_op(in_thread_val_buf(Number{}), + in_thread_val_buf(Number{})); + }); + }); + + ThreadwiseReduceWithIndex::Reduce( + in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + indexStart += KThreadSliceSize; + reducedLength += KThreadSliceSize; + } while(reducedLength < toReduceLength); + + if constexpr(TransformIndexKtoGlobal) + { + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + const auto coord = make_tensor_coordinate( + in_grid_desc_m_k, + make_multi_index(thread_global_1d_id * MThreadSliceSize + I, + accu_index_buf(I))); + + accu_index_buf(I) = coord.GetOffset(); + }); + } + }; + + // for indiced operation, acc_elementwise_op shoud do nothing + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + acc_elementwise_op(accu_value_buf(I), accu_value_buf(I)); + + accu_value_buf(I) *= alpha; + }); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + if(!float_equal_zero{}(beta)) + { + auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2, + Sequence<0>, + 0, + 1, + 1, + false>( + out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize)); + + StaticBuffer + priorDstValue_buf; + + threadwise_dst_load.Run(out_grid_desc_m, + out_global_val_buf, + reduced_data_desc, + make_tuple(I0), + priorDstValue_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + 
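+                    // add the prior output scaled by beta on top of alpha * reduced value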
accu_value_buf(I) += type_convert(priorDstValue_buf[I]) * beta; + }); + }; + + auto threadwise_dst_val_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + false>( + out_grid_desc_m, + make_multi_index(thread_global_1d_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_dst_idx_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + false>( + out_grid_desc_m, + make_multi_index(thread_global_1d_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_dst_val_store.Run( + reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_val_buf); + + threadwise_dst_idx_store.Run( + reduced_data_desc, make_tuple(I0), accu_index_buf, out_grid_desc_m, out_global_idx_buf); + }; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp new file mode 100755 index 000000000000..e3c50ef06caa --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise_multi_d.hpp @@ -0,0 +1,260 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/tuple_helper.hpp" + +namespace ck { + +template +__global__ void +kernel_reduce_threadwise_multi_d(const InGridDesc_M_K in_grid_desc_m_k, + const DsGridDesc_M ds_grid_desc_m, + const OutGridDesc_M out_grid_desc_m, + const InElementwiseOperation in_elementwise_op, + const OutElementwiseOperation out_elementwise_op, + const InDataType* const __restrict__ p_in_value_global, + const DsGridPointer p_ds_value_global, + OutDataType* const __restrict__ p_out_value_global) +{ + GridwiseReduction::Run(in_grid_desc_m_k, + ds_grid_desc_m, + out_grid_desc_m, + in_elementwise_op, + out_elementwise_op, + p_in_value_global, + p_ds_value_global, + p_out_value_global); +} + +template +struct GridwiseReduction_mk_to_m_threadwise_multi_d +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (MThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using PassThrough = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + 
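For reference, the fusion this multi-D variant adds can be sketched on the host as follows: after each row is reduced, the accumulator is combined with per-row auxiliary tensors D0..Dn through OutElementwiseOperation before the store. The two D tensors and the "scale then bias" functor below are placeholders chosen purely for illustration.

#include <cstddef>
#include <vector>

// Illustrative multi-D reduction: out[m] = out_op(reduce_k(in[m][k]), D0[m], D1[m]).
std::vector<float> reduce_mk_to_m_multi_d(const std::vector<float>& in,  // M x K, row-major
                                          std::size_t M, std::size_t K,
                                          const std::vector<float>& d0,  // length M
                                          const std::vector<float>& d1)  // length M
{
    std::vector<float> out(M);
    for(std::size_t m = 0; m < M; ++m)
    {
        float acc = 0.0f;                   // identity of the Add reduction
        for(std::size_t k = 0; k < K; ++k)
            acc += in[m * K + k];
        out[m] = acc * d0[m] + d1[m];       // stand-in for out_elementwise_op(acc, D0[m], D1[m])
    }
    return out;
}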
using DsGridPointer = decltype(MakeDsGridPointer()); + + __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k, + const DsGridDesc_M& ds_grid_desc_m, + const OutGridDesc_M& out_grid_desc_m, + const InElementwiseOperation& in_elementwise_op, + const OutElementwiseOperation& out_elementwise_op, + const InDataType* const __restrict__ p_in_value_global, + const DsGridPointer p_ds_grid, + OutDataType* const __restrict__ p_out_value_global) + { + using ThreadwiseReduce = ThreadwiseReduction; + + const auto identityVal = ReduceOperation::template GetIdentityValue(); + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, + in_grid_desc_m_k.GetElementSpaceSize(), + ReduceOperation::template GetIdentityValue()); + auto dst_global_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m.GetElementSpaceSize()); + + StaticBuffer + in_thread_buf; + + StaticBuffer accu_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; }); + + const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{}); + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); + + auto threadwise_src_val_load = + ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0)); + + constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize); + + index_t reducedLength = 0; + do + { + threadwise_src_val_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // do element-wise pre-reduction operation + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + in_elementwise_op(in_thread_buf(Number{}), + in_thread_buf(Number{})); + }); + }); + + ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); + + threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step); + + reducedLength += KThreadSliceSize; + } while(reducedLength < toReduceLength); + + constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{}; + + auto ds_thread_buf = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return StaticBuffer{}; + }, + Number{}); + + auto ds_global_buf = generate_tuple( + [&](auto I) { + return make_dynamic_buffer( + p_ds_grid[I], ds_grid_desc_m[I].GetElementSpaceSize()); + }, + Number{}); + + auto ds_global_load = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + InSrcVectorDim, // SrcVectorDim + DsVectorSize{}[I], + 1, // SrcScalarStrideInVector + true>{ + ds_grid_desc_m[I], make_multi_index(thread_global_1d_id * MThreadSliceSize)}; + }, + Number{}); + + static_for<0, NumDTensor, 1>{}([&](auto I) { + ds_global_load(I).Run(ds_grid_desc_m[I], + ds_global_buf[I], + reduced_data_desc, + make_tuple(I0), + ds_thread_buf(I)); + }); + + StaticBuffer out_value_buf; + + // if constexpr(NumDTensor > 0) + { + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(accu_value_buf[I]), + generate_tie( + [&](auto Id) -> const auto& { return 
ds_thread_buf[Id][I]; }, + Number{})); + + unpack2(out_elementwise_op, tie(out_value_buf(I)), c_ds_buf_refs); + }); + } + + auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + OutDstVectorSize, + OutMemoryDataOperation, + 1, + false>( + out_grid_desc_m, + make_multi_index(thread_global_1d_id * MThreadSliceSize), + PassThrough{}); + + threadwise_dst_store.Run( + reduced_data_desc, make_tuple(I0), out_value_buf, out_grid_desc_m, dst_global_buf); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp new file mode 100755 index 000000000000..258d0ad0caa5 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp @@ -0,0 +1,946 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +struct GridwiseBatchedGemmGemm_Xdl_CShuffle +{ + static_assert(LoopSched == LoopScheduler::Default, + "Non-default loop scheduler is currently not supported"); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + // Gemm0 + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + // Gemm1 + static constexpr auto B1K0 = Number{}; + static constexpr auto B1K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + template + __host__ __device__ static constexpr auto + MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + return 
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B1 matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(B1K0, Number{}, B1K1), + make_tuple(Number{} * B1K1, B1K1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned + + SharedMemTrait::b_block_space_size_aligned) * + sizeof(FloatAB); + const index_t gemm1_bytes_end = + (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) * + sizeof(FloatAB); + const index_t c_block_bytes_end = + SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle); + + return math::max(gemm0_bytes_end, gemm1_bytes_end, c_block_bytes_end); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1))) + { + return false; + } + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 && + Gemm1N % Gemm1NPerBlock == 0)) + { + return false; + } + + // check gemm0 gridwise gemm pipeline + const auto num_gemm0_k_loop = K / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop)) + { + return false; + } + + // check gemm1 gridwise gemm pipeline + if(!(NPerBlock % Gemm1KPerBlock == 0)) + { + return false; + } + + const auto 
num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / Gemm1NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + static constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + static constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + static constexpr auto b1_block_desc_bk0_n_bk1 = + GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1); + + static constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple( + b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b_block_space_offset = a_block_space_size_aligned.value; + static constexpr auto b1_block_space_offset = 0; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + static constexpr auto c_block_space_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + }; + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatAB* __restrict__ p_b1_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const AccElementwiseOperation& acc_element_op, + const B1ElementwiseOperation& b1_element_op, + const CElementwiseOperation& c_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const 
BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + const auto b1_grid_buf = make_dynamic_buffer( + p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // + // set up Gemm0 + // + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + true, // SrcResetCoord + true, // DstResetCoord + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + true, // SrcResetCoord + true, // DstResetCoord + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), // will loop over GemmN dimension + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + // Fused Gemm+Gemm pipeline + // for n in N0: + // for k in K0: + // acc[m][n] += A[m][k] * B0[k][n] + // acc1[m][o] += acc[m][n] * B1[n][o] + + // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + 
? true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( + lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_v2< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)), + decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)), + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + true>{}; // TransposeC + + auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a_block_space_offset, + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b_block_space_offset, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + const auto a_block_reset_copy_step = + make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0); + const auto b_block_reset_copy_step = + make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0); + + // gridwise GEMM pipeline + // Only supports LoopScheduler::Default + const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // + // set up Gemm1 + // + + // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type + constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); + constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); + constexpr auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); + constexpr auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); + constexpr auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); + constexpr auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); + constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); + constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); + + constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0); + + // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1 + // n0_n1_n2_n3 -> k0 + // m0_m1_m2 -> m + // n4 -> k1 + // NOTE: had to use merge_v3 or will spit out compilation errors + constexpr auto acc_thread_desc_k0_m_k1 = transform_tensor_descriptor( + acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)), + make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)), + make_pass_through_transform(n4)), + make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // A1 matrix in AccVGPR + // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size + constexpr auto AccN3 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6); + + constexpr auto A1ThreadSlice_K0_M_K1 = + 
make_tuple(Number{}, Number{}, Number{}); + + constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0]; + constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1]; + constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2]; + constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor( + A1ThreadSlice_K0_M_K1, + make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1)); + + // B1 matrix in LDS memory, dst of blockwise copy + constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A1 matrix blockwise copy + auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< + FloatGemmAcc, + FloatAB, + decltype(acc_thread_desc_k0_m_k1), + decltype(a1_thread_desc_k0_m_k1), + decltype(acc_element_op), + Sequence, + Sequence<1, 0, 2>, + 2, + n4>{acc_element_op}; + + // B1 matrix blockwise copy + auto b1_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b1_grid_desc_bk0_n_bk1), + decltype(b1_block_desc_bk0_n_bk1), + B1BlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + B1BlockTransferSrcVectorDim, + 2, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + 1, + 1, + B1ThreadTransferSrcResetCoordinateAfterRun, + true, // DstResetCoord + NumGemmKPrefetchStage>( + b1_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b1_element_op, + b1_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + auto a1_thread_buf = make_static_buffer( + a1_thread_desc_k0_m_k1.GetElementSpaceSize()); + + // reuse LDS space for gemm0's b_block_buf + auto b1_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, + b1_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size + // selected_mfma.k_per_blk <= Gemm1KPack + // + // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common + // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case + // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs + // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will + // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. 
+ // therefore we may just as well assign Gemm1KPack = group_size + + constexpr index_t Gemm1KPack = + MfmaSelector::selected_mfma.group_size; + + auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a1_thread_desc_k0_m_k1), + decltype(b1_block_desc_bk0_n_bk1), + decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)), + decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)), + MPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + Gemm1NXdlPerWave, + Gemm1KPack, + false, // TransposeC + Gemm1KPack, // AMmaKStride + Gemm1KPack * + XdlopsGemm{}.K0PerXdlops>{ + // BMmaKStride + make_tuple(0, 0, 0, 0)}; // A_origin + + auto c_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer(); + + const index_t num_gemm1_k_block_outer_loop = + b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock; + constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock; + + // Initialize C + c_thread_buf.Clear(); + + // gemm1 K loop + index_t gemm1_k_block_outer_index = 0; + do + { + // gemm0 + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + acc_thread_buf, + num_k_block_main_loop); + // gemm1 + { + // TODO: explore using dynamic buffer for a1 thread buffer + // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(), + // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that + // the A1 source buffer is static buffer holding the output of first GEMM and + // requires constexpr offset by design. Therefore, we pass tensor coordinate offset + // explicitly in Run() below. 
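As a rough functional model of what this fused loop computes: Gemm0 produces acc = A * B0 one NPerBlock-wide slab at a time, and Gemm1 immediately consumes that slab as the A operand of acc1 = acc * B1, so the intermediate M x N product never reaches global memory. The sketch below follows the A/B0/B1 naming in the pipeline comment above; all tiling, LDS staging and MFMA layout details are elided and the helper name is illustrative.

#include <cstddef>
#include <vector>

// Host-side reference for the fused Gemm+Gemm pipeline: C[M x O] = (A[M x K] * B0[K x N]) * B1[N x O].
std::vector<float> fused_gemm_gemm(const std::vector<float>& A,   // M x K, row-major
                                   const std::vector<float>& B0,  // K x N, row-major
                                   const std::vector<float>& B1,  // N x O, row-major
                                   std::size_t M, std::size_t K, std::size_t N, std::size_t O)
{
    std::vector<float> C(M * O, 0.0f);
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.0f;
            for(std::size_t k = 0; k < K; ++k)
                acc += A[m * K + k] * B0[k * N + n];   // Gemm0: acc[m][n]
            for(std::size_t o = 0; o < O; ++o)
                C[m * O + o] += acc * B1[n * O + o];   // Gemm1: acc1[m][o] += acc[m][n] * B1[n][o]
        }
    return C;
}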
+ + // preload data into LDS + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, + b1_block_slice_copy_step); + + block_sync_lds(); // wait for gemm0 LDS read + + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); + + // main body + if constexpr(num_gemm1_k_block_inner_loop > 1) + { + static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) { + a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1, + make_tuple(Number{}, I0, I0), + acc_thread_buf, + a1_thread_desc_k0_m_k1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); + + block_sync_lds(); + + gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, c_thread_buf); + + block_sync_lds(); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, + b1_block_slice_copy_step); + + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); + }); + } + // tail + { + a1_blockwise_copy.Run( + acc_thread_desc_k0_m_k1, + make_tuple( + Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0), + acc_thread_buf, + a1_thread_desc_k0_m_k1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + block_sync_lds(); + + gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, c_thread_buf); + } + } // end gemm1 + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1, + a_block_reset_copy_step); // rewind K + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1, + b_block_reset_copy_step); // rewind K and step N + + block_sync_lds(); // wait for gemm1 LDS read + } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + gemm1_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + gemm1_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto 
c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp new file mode 100755 index 000000000000..53a45c7f1675 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp @@ -0,0 +1,1299 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
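+
+// Illustrative summary of what this gridwise kernel computes, based on the
+// pipeline comments and the epilogue further down (names follow the template
+// parameters; this is a sketch, not a specification):
+//
+//   for n in N0:                              // outer gemm1 K loop
+//       for k in K0:                          // gemm0 K loop
+//           acc0[m][n] += A0[m][k] * B0[k][n]
+//       acc0 = cde0_element_op(acc0, D0s...)  // per-element fusion of the D0 tensors
+//       acc1[m][o] += acc0[m][n] * B1[n][o]   // gemm1
+//   E1 = cde1_element_op(acc1, D1s...)        // applied while shuffling out through LDS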
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle +{ + static_assert(LoopSched == LoopScheduler::Default, + "Non-default loop scheduler is currently not supported"); + + static constexpr index_t NumD0Tensor = D0sDataType::Size(); + static constexpr index_t NumD1Tensor = D1sDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto WaveSize = 64; + // K1 should be Number<...> + // Gemm0 + static constexpr auto A0K1 = Number{}; + static constexpr auto B0K1 = Number{}; + + static constexpr auto A0K0PerBlock = Number{}; + static constexpr auto B0K0PerBlock = Number{}; + + static constexpr auto Gemm0MWaves = Gemm0MPerBlock / (Gemm0MPerXdl * Gemm0MXdlPerWave); + static constexpr auto Gemm0NWaves = Gemm0NPerBlock / (Gemm0NPerXdl * Gemm0NXdlPerWave); + // Gemm1 + static constexpr auto B1K1 = Number{}; + static constexpr auto B1K0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + // ck::Tuple + static constexpr auto MakeD0sGridPointer() + { + return generate_tuple( + [&](auto i) { + using D0DataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + // ck::Tuple + static constexpr auto MakeD1sGridPointer() + { + return generate_tuple( + [&](auto i) { + using D1DataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __device__ static auto GetGemm0WaveIdx() + { + const index_t thread_id = get_thread_local_1d_id(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto GetGemm0WaveMNIdx(const index_t thread_id) + { + constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(WaveSize / Gemm0NPerXdl, Gemm0NPerXdl))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + template + __host__ __device__ static constexpr auto + MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const A0BlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = Gemm0MPerBlock / (Gemm0MXdlPerWave * Gemm0MPerXdl); + + return 
MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + A0BlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = Gemm0NPerBlock / (Gemm0NXdlPerWave * Gemm0NPerXdl); + + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const A0BlockDesc_AK0_M_AK1&) + { + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + A0BlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * Gemm0NPerXdl); + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static constexpr auto GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A0 matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(A0K0PerBlock, Number{}, A0K1), + make_tuple(Number{} * A0K1, A0K1, I1)); + } + + __host__ __device__ static constexpr auto GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B0 matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(B0K0PerBlock, Number{}, B0K1), + make_tuple(Number{} * B0K1, B0K1, I1)); + } + + __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B1 matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(B1K0PerBlock, Number{}, B1K1), + make_tuple(Number{} * B1K1, B1K1, I1)); + } + + __host__ __device__ static constexpr auto + GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = Gemm0MPerBlock / (Gemm0MXdlPerWave * Gemm0MPerXdl); + constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * Gemm0NPerXdl); + + constexpr auto c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + const index_t gemm0_bytes_end = (SharedMemTrait::a0_block_space_size_aligned + + SharedMemTrait::b0_block_space_size_aligned) * + sizeof(A0B0B1DataType); + const index_t gemm1_bytes_end = + (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) * + sizeof(A0B0B1DataType); + const index_t c1_block_bytes_end = + SharedMemTrait::c1_block_space_size * sizeof(C1ShuffleDataType); + + return math::max(gemm0_bytes_end, gemm1_bytes_end, c1_block_bytes_end); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const A0GridDesc_M_K& a0_grid_desc_m_k, + const B0GridDesc_N_K& b0_grid_desc_n_k, + const B1GridDesc_N_K& b1_grid_desc_n_k, + const E1GridDesc_M_N& e1_grid_desc_m_n, + const Block2E1TileMap& block_2_e1tile_map) + { + static_assert((Gemm0MPerBlock % (Gemm0MPerXdl * Gemm0MXdlPerWave) == 0) && + (Gemm0NPerBlock % (Gemm0NXdlPerWave * Gemm0NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a0_grid_desc_m_k.GetLength(I0); + const auto N = b0_grid_desc_n_k.GetLength(I0); + const auto K = a0_grid_desc_m_k.GetLength(I1); + const auto Gemm1N = b1_grid_desc_n_k.GetLength(I0); + + if(!(M 
== e1_grid_desc_m_n.GetLength(I0) && Gemm1N == e1_grid_desc_m_n.GetLength(I1))) + { + return false; + } + + if(!(M % Gemm0MPerBlock == 0 && N % Gemm0NPerBlock == 0 && K % Gemm0KPerBlock == 0 && + Gemm1N % Gemm1NPerBlock == 0)) + { + return false; + } + + // check gemm0 gridwise gemm pipeline + const auto num_gemm0_k_loop = K / Gemm0KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop)) + { + return false; + } + + // check gemm1 gridwise gemm pipeline + if(!(Gemm0NPerBlock % Gemm1KPerBlock == 0)) + { + return false; + } + + const auto num_gemm1_k_inner_loop = Gemm0NPerBlock / Gemm1KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop)) + { + return false; + } + + if(!block_2_e1tile_map.CheckValidity(e1_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / Gemm0KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + // A0 desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultA0GridDescriptor_AK0_M_AK1(const A0GridDesc_M_K& a0_grid_desc_m_k) + { + const auto M = a0_grid_desc_m_k.GetLength(I0); + const auto K = a0_grid_desc_m_k.GetLength(I1); + + const auto A0K0 = K / A0K1; + + return transform_tensor_descriptor( + a0_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(A0K0, A0K1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // B0 desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultB0GridDescriptor_BK0_N_BK1(const B0GridDesc_N_K& b0_grid_desc_n_k) + { + const auto N = b0_grid_desc_n_k.GetLength(I0); + const auto K = b0_grid_desc_n_k.GetLength(I1); + + const auto B0K0 = K / B0K1; + + return transform_tensor_descriptor( + b0_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B0K0, B0K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // D0 desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N& d0_grid_desc_m_n) + { + const auto M = d0_grid_desc_m_n.GetLength(I0); + const auto N = d0_grid_desc_m_n.GetLength(I1); + + constexpr auto lcm_A0K1_B0K1 = math::lcm(A0K1, B0K1); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_A0K1_B0K1 <= 4) || + (is_same::value && lcm_A0K1_B0K1 <= 8) || + ((is_same::value || is_same::value) && + lcm_A0K1_B0K1 < 32)) + ? 
true + : false; + constexpr auto is_scale_mfma = false; + constexpr auto mfma = MfmaSelector::selected_mfma; + constexpr auto N3 = mfma.num_groups_per_blk; + constexpr auto N5 = mfma.group_size; + return transform_tensor_descriptor( + d0_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple( + M / Gemm0MPerBlock, Gemm0MXdlPerWave, Gemm0MWaves, Gemm0MPerXdl)), + make_unmerge_transform(make_tuple(N / Gemm0NPerBlock, + Gemm0NXdlPerWave, + Gemm0NWaves, + N3, + WaveSize / Gemm0NPerXdl, + N5))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{})); + } + + // B1 desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k) + { + const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // C1 desc for destination in blockwise copy + __host__ __device__ static constexpr auto + MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const E1GridDesc_M_N& e1_grid_desc_m_n) + { + const auto M = e1_grid_desc_m_n.GetLength(I0); + const auto N = e1_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / Gemm0MPerBlock; + const auto NBlock = N / Gemm1NPerBlock; + + const auto e1_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e1_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e1_grid_desc_mblock_mperblock_nblock_nperblock; + } + // D0s desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0sGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ds_grid_desc_m_n[i]); + }, + Number{}); + } + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDescriptor_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to C1 matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2E1TileMap(const E1GridDesc_M_N& e1_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e1_grid_desc_m_n); + } + + using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = + remove_cvref_t; + + using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + using DefaultBlock2E1TileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A0 and B0: be careful of alignment + static constexpr auto a0_block_desc_ak0_m_ak1 = + GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + static constexpr auto b0_block_desc_bk0_n_bk1 = + GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + static constexpr auto b1_block_desc_bk0_n_bk1 = + 
GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + static constexpr auto max_lds_align = math::lcm(math::lcm(A0K1, B0K1), B1K1); + + static constexpr auto a0_block_space_size_aligned = math::integer_least_multiple( + a0_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b0_block_space_size_aligned = math::integer_least_multiple( + b0_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple( + b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + static constexpr auto a0_block_space_offset = 0; + static constexpr auto b0_block_space_offset = a0_block_space_size_aligned.value; + static constexpr auto b1_block_space_offset = 0; + + // LDS allocation for C1 shuffle in LDS + static constexpr auto c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + static constexpr auto c1_block_space_size = + c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + }; + + using D0sGridPointer = decltype(MakeD0sGridPointer()); + using D1sGridPointer = decltype(MakeD1sGridPointer()); + + template + __device__ static void Run(const A0B0B1DataType* __restrict__ p_a0_grid, + const A0B0B1DataType* __restrict__ p_b0_grid, + D0sGridPointer p_d0s_grid, + const A0B0B1DataType* __restrict__ p_b1_grid, + D1sGridPointer p_d1s_grid, + E1DataType* __restrict__ p_e1_grid, + void* __restrict__ p_shared, + const A0ElementwiseOperation& a0_element_op, + const B0ElementwiseOperation& b0_element_op, + const CDE0ElementwiseOperation& cde0_element_op, + const B1ElementwiseOperation& b1_element_op, + const CDE1ElementwiseOperation& cde1_element_op, + const A0GridDesc_AK0_M_AK1& a0_grid_desc_ak0_m_ak1, + const B0GridDesc_BK0_N_BK1& b0_grid_desc_bk0_n_bk1, + const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5& + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, + const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + d1s_grid_desc_mblock_mperblock_nblock_nperblock, + const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + e1_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2E1TileMap& block_2_e1tile_map) + { + const auto a0_grid_buf = make_dynamic_buffer( + p_a0_grid, a0_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b0_grid_buf = make_dynamic_buffer( + p_b0_grid, b0_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + const auto b1_grid_buf = make_dynamic_buffer( + p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto e1_grid_buf = make_dynamic_buffer( + p_e1_grid, e1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + const auto d0s_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_d0s_grid[i], + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i].GetElementSpaceSize()); + }, + Number{}); + const auto d1s_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_d1s_grid[i], + d1s_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_e1tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_e1tile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e1_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force 
m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * Gemm0MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock); + + // A0 matrix in LDS memory, dst of blockwise copy + constexpr auto a0_block_desc_ak0_m_ak1 = GetA0BlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B0 matrix in LDS memory, dst of blockwise copy + constexpr auto b0_block_desc_bk0_n_bk1 = GetB0BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // + // set up Gemm0 + // + + // A0 matrix blockwise copy + auto a0_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + A0BlockTransferThreadClusterLengths_AK0_M_AK1, + A0BlockTransferThreadClusterArrangeOrder, + A0B0B1DataType, + A0B0B1DataType, + decltype(a0_grid_desc_ak0_m_ak1), + decltype(a0_block_desc_ak0_m_ak1), + A0BlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + A0BlockTransferSrcVectorDim, + 2, + A0BlockTransferSrcScalarPerVector, + A0BlockTransferDstScalarPerVector_AK1, + 1, + 1, + true, // SrcResetCoord + true, // DstResetCoord + NumGemm0KPrefetchStage>( + a0_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a0_element_op, + a0_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + // B0 matrix blockwise copy + auto b0_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + B0BlockTransferThreadClusterLengths_BK0_N_BK1, + B0BlockTransferThreadClusterArrangeOrder, + A0B0B1DataType, + A0B0B1DataType, + decltype(b0_grid_desc_bk0_n_bk1), + decltype(b0_block_desc_bk0_n_bk1), + B0BlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + B0BlockTransferSrcVectorDim, + 2, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_BK1, + 1, + 1, + true, // SrcResetCoord + true, // DstResetCoord + NumGemm0KPrefetchStage>( + b0_grid_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), // will loop over GemmN dimension + b0_element_op, + b0_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + // Fused Gemm+Gemm pipeline + // for n in N0: + // for k in K0: + // acc[m][n] += A[m][k] * B0[k][n] + // acc1[m][o] += acc[m][n] * B1[n][o] + + // sanity check + constexpr auto lcm_A0K1_B0K1 = math::lcm(A0K1, B0K1); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_A0K1_B0K1 <= 4) || + (is_same::value && lcm_A0K1_B0K1 <= 8) || + ((is_same::value || is_same::value) && + lcm_A0K1_B0K1 < 32)) + ? 
true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_A0K1_B0K1, + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm0 = BlockwiseGemmXdlops_v2< + BlockSize, + A0B0B1DataType, + Acc0DataType, + decltype(a0_block_desc_ak0_m_ak1), + decltype(b0_block_desc_bk0_n_bk1), + decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a0_block_desc_ak0_m_ak1)), + decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b0_block_desc_bk0_n_bk1)), + Gemm0MPerBlock, + Gemm0NPerBlock, + Gemm0KPerBlock, + Gemm0MPerXdl, + Gemm0NPerXdl, + Gemm0MXdlPerWave, + Gemm0NXdlPerWave, + KPack, + true>{}; // TransposeC + + auto acc0_thread_buf = blockwise_gemm0.GetCThreadBuffer(); + + // LDS allocation for A0 and B0: be careful of alignment + auto a0_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a0_block_space_offset, + a0_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b0_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b0_block_space_offset, + b0_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a0_block_slice_copy_step = make_multi_index(Gemm0KPerBlock / A0K1, 0, 0); + constexpr auto b0_block_slice_copy_step = make_multi_index(Gemm0KPerBlock / B0K1, 0, 0); + const auto a0_block_reset_copy_step = + make_multi_index(-a0_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0); + const auto b0_block_reset_copy_step = + make_multi_index(-b0_grid_desc_bk0_n_bk1.GetLength(I0), Gemm0NPerBlock, 0); + + // gridwise GEMM pipeline + // Only supports LoopScheduler::Default + const auto gridwise_gemm0_pipeline = + GridwiseGemmPipeline_v1_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a0_grid_desc_ak0_m_ak1.GetLength(I0) * a0_grid_desc_ak0_m_ak1.GetLength(I2)) / + Gemm0KPerBlock); + + // + // set up Gemm1 + // + + // Acc0 matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type + constexpr auto acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm0.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto m0 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); + constexpr auto n0 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); + constexpr auto m1 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); + constexpr auto n1 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); + constexpr auto m2 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); + constexpr auto n2 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); + constexpr auto n3 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); + constexpr auto n4 = acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); + + constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0); + + // d0 matrix threadwise copy + constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = + make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId + I1, // NBlockID + m0, // MRepeat + n0, // NRepeat + m1, // MWaveId + n1, // NWaveId + m2, // MPerXdl + n2, // NGroupNum + n3, // NInputNum + n4)); // registerNum + + auto d0s_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer< + AddressSpaceEnum::Vgpr, + A0B0B1DataType, + d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(), + true>{}; + }, + Number{}); + + const auto wave_id = GetGemm0WaveIdx(); + const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63 + + static_assert(CDE0BlockTransferSrcScalarPerVector <= n4, + "vector load must be not greater than n4"); + 
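+        // Together with the assert below, this restricts the D0 read vector width to
+        // a divisor of the mfma register group size n4. For example, if n4 = 4
+        // (illustrative), legal CDE0BlockTransferSrcScalarPerVector values are 1, 2
+        // and 4, so a single vector load never straddles two register groups.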
static_assert(n4 % CDE0BlockTransferSrcScalarPerVector == 0); + + auto d0s_threadwise_copy = generate_tuple( + [&](auto i) { + return ThreadwiseTensorSliceTransfer_v2< + A0B0B1DataType, + A0B0B1DataType, + decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]), + decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, // CDE0BlockTransferSrcVectorDim + CDE0BlockTransferSrcScalarPerVector, + 1, + false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + make_multi_index(block_work_idx[I0], // MBlockId + 0, // NBlockId + 0, // mrepeat + 0, // nrepeat + wave_id[I0], // MWaveId + wave_id[I1], // NWaveId + wave_m_n_id[I1], // MPerXdl + 0, // group + wave_m_n_id[I0], // NInputIndex + 0)); // register number + }, + Number{}); + // acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc0_thread_desc_k0_m_k1 + // n0_n1_n2_n3 -> k0 + // m0_m1_m2 -> m + // n4 -> k1 + // NOTE: had to use merge_v3 or will spit out compilation errors + constexpr auto acc0_thread_desc_k0_m_k1 = transform_tensor_descriptor( + acc0_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)), + make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)), + make_pass_through_transform(n4)), + make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // A1 matrix in AccVGPR + // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size + constexpr auto Acc0N3 = + blockwise_gemm0.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6); + + constexpr auto A1ThreadSlice_K0_M_K1 = make_tuple( + Number{}, Number{}, Number{}); + + constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0]; + constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1]; + constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2]; + constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor( + A1ThreadSlice_K0_M_K1, + make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1)); + + // B1 matrix in LDS memory, dst of blockwise copy + constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A1 matrix blockwise copy + auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< + Acc0DataType, + A0B0B1DataType, + decltype(acc0_thread_desc_k0_m_k1), + decltype(a1_thread_desc_k0_m_k1), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<1, 0, 2>, + 2, + n4>{tensor_operation::element_wise::PassThrough{}}; + + // B1 matrix blockwise copy + auto b1_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + A0B0B1DataType, + A0B0B1DataType, + decltype(b1_grid_desc_bk0_n_bk1), + decltype(b1_block_desc_bk0_n_bk1), + B1BlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + B1BlockTransferSrcVectorDim, + 2, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + 1, + 1, + B1ThreadTransferSrcResetCoordinateAfterRun, + true, // DstResetCoord + 1>(b1_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b1_element_op, + b1_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + auto a1_thread_buf = make_static_buffer( + a1_thread_desc_k0_m_k1.GetElementSpaceSize()); + + // reuse LDS space for gemm0's b0_block_buf + auto b1_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, + 
b1_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size + // selected_mfma.k_per_blk <= Gemm1KPack + // + // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common + // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case + // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs + // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will + // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. + // therefore we may just as well assign Gemm1KPack = group_size + + constexpr index_t Gemm1KPack = + MfmaSelector::selected_mfma.group_size; + + auto blockwise_gemm1 = BlockwiseGemmXdlops_v2< + BlockSize, + A0B0B1DataType, + Acc1DataType, + decltype(a1_thread_desc_k0_m_k1), + decltype(b1_block_desc_bk0_n_bk1), + decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)), + decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)), + Gemm0MPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + Gemm0MPerXdl, + Gemm0NPerXdl, + Gemm0MXdlPerWave, + Gemm1NXdlPerWave, + Gemm1KPack, + false, // TransposeC + Gemm1KPack, // AMmaKStride + Gemm1KPack * XdlopsGemm{} + .K0PerXdlops>{ // BMmaKStride + make_tuple(0, 0, 0, 0)}; // A_origin + + auto c1_thread_buf = blockwise_gemm1.GetCThreadBuffer(); + + const index_t num_gemm1_k_block_outer_loop = + b0_grid_desc_bk0_n_bk1.GetLength(I1) / Gemm0NPerBlock; + constexpr index_t num_gemm1_k_block_inner_loop = Gemm0NPerBlock / Gemm1KPerBlock; + + // Initialize C1 + c1_thread_buf.Clear(); + + // gemm1 K loop + index_t gemm1_k_block_outer_index = 0; + do + { + // gemm0 + gridwise_gemm0_pipeline.template Run(a0_grid_desc_ak0_m_ak1, + a0_block_desc_ak0_m_ak1, + a0_blockwise_copy, + a0_grid_buf, + a0_block_buf, + a0_block_slice_copy_step, + b0_grid_desc_bk0_n_bk1, + b0_block_desc_bk0_n_bk1, + b0_blockwise_copy, + b0_grid_buf, + b0_block_buf, + b0_block_slice_copy_step, + blockwise_gemm0, + acc0_thread_buf, + num_k_block_main_loop); + // multiple d + if constexpr(NumD0Tensor) + { + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + d0s_grid_buf[i], + d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + d0s_thread_buf(i)); + }); + static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto) -> auto& { return acc0_thread_buf(i); }, + Number<2>{}); + + unpack2(cde0_element_op, dst_data_refs, src_data_refs); + }); + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).MoveSrcSliceWindow( + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); + }); + } + else + { + static_for<0, acc0_thread_buf.Size(), 1>{}( + [&](auto i) { cde0_element_op(acc_thread_buf(i), acc0_thread_buf[i]); }); + } + // gemm1 + { + // TODO: explore using dynamic buffer for a1 thread buffer + // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(), + // RunWrite(), and MoveSliceWindow(). 
But it is impossible to implement given that + // the A1 source buffer is static buffer holding the output of first GEMM and + // requires constexpr offset by design. Therefore, we pass tensor coordinate offset + // explicitly in Run() below. + + // preload data into LDS + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, + b1_block_slice_copy_step); + + block_sync_lds(); // wait for gemm0 LDS read + + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); + + // main body + if constexpr(num_gemm1_k_block_inner_loop > 1) + { + static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) { + a1_blockwise_copy.Run(acc0_thread_desc_k0_m_k1, + make_tuple(Number{}, I0, I0), + acc0_thread_buf, + a1_thread_desc_k0_m_k1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); + + block_sync_lds(); + + blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, c1_thread_buf); + + block_sync_lds(); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, + b1_block_slice_copy_step); + + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); + }); + } + // tail + { + a1_blockwise_copy.Run( + acc0_thread_desc_k0_m_k1, + make_tuple( + Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0), + acc0_thread_buf, + a1_thread_desc_k0_m_k1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + block_sync_lds(); + + blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, c1_thread_buf); + } + } // end gemm1 + + a0_blockwise_copy.MoveSrcSliceWindow(a0_grid_desc_ak0_m_ak1, + a0_block_reset_copy_step); // rewind K + b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc_bk0_n_bk1, + b0_block_reset_copy_step); // rewind K and step N + + block_sync_lds(); // wait for gemm1 LDS read + } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop + + // shuffle C1 and write out + { + static_assert(Gemm0MXdlPerWave % C1ShuffleGemm0MXdlPerWavePerShuffle == 0 && + Gemm1NXdlPerWave % C1ShuffleGemm0NXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = Gemm0MPerBlock / (Gemm0MXdlPerWave * Gemm0MPerXdl); + constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * Gemm0NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm1.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
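+            // Based on the divisibility asserts above and the space-filling curves
+            // below, the write-out appears to run in
+            //   (Gemm0MXdlPerWave / C1ShuffleGemm0MXdlPerWavePerShuffle) *
+            //   (Gemm1NXdlPerWave / C1ShuffleGemm0NXdlPerWavePerShuffle)
+            // passes, each staging one C1 sub-tile in LDS. That LDS region is the same
+            // allocation used by the GEMM stages (GetSharedMemoryNumberOfByte takes the
+            // max of the three), which is why every pass is bracketed by block_sync_lds().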
+ // c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm1.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetC1ShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c1_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (Gemm0MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = Gemm0MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (Gemm0NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2))), // N2 = Gemm0NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM C1 matrix starting index + const auto c1_thread_mtx_on_block = + blockwise_gemm1.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c1_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c1_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c1_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor 
descriptors + const auto c1_d1s_desc_refs = concat_tuple_of_reference( + tie(c1_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return d1s_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c1_d1s_buf_refs = concat_tuple_of_reference( + tie(c1_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return d1s_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c1_d1s_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // shuffle: blockwise copy C from LDS to global + auto cde1_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(C1ShuffleDataType{}), D1sDataType{})), + Tuple, + decltype(c1_d1s_desc_refs), + decltype(tie(e1_grid_desc_mblock_mperblock_nblock_nperblock)), + CDE1ElementwiseOperation, + Sequence(E1GlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray + // type + Sequence<1, + C1ShuffleGemm0MXdlPerWavePerShuffle * MWave * Gemm0MPerXdl, + 1, + C1ShuffleGemm0NXdlPerWavePerShuffle * NWave * + Gemm0NPerXdl>, // BlockSliceLengths, + CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDE1ShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c1_d1s_desc_refs, + idx_c1_d1s_block_begin, + tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde1_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c1_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_e1_global = SpaceFillingCurve< + Sequence<1, Gemm0MPerBlock, 1, Gemm1NPerBlock>, + Sequence<0, 2, 1, 3>, + Sequence<1, + C1ShuffleGemm0MXdlPerWavePerShuffle * MWave * Gemm0MPerXdl, + 1, + C1ShuffleGemm0NXdlPerWavePerShuffle * NWave * Gemm0NPerXdl>>{}; + + constexpr index_t num_access = sfc_c1_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_e1_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c1_thread_copy_vgpr_to_lds.Run(c1_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c1_vgpr.GetIndexTupleOfNumber(access_id), + c1_thread_buf, + c1_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c1_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde1_shuffle_block_copy_lds_to_global.Run( + c1_d1s_desc_refs, + c1_d1s_buf_refs, + tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e1_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto e1_global_step = sfc_e1_global.GetForwardStep(access_id); + + // move on D1s + static_for<0, 
NumD1Tensor, 1>{}([&](auto i) { + cde1_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow( + c1_d1s_desc_refs, i + I1, e1_global_step); + }); + + // move on C + cde1_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + tie(e1_grid_desc_mblock_mperblock_nblock_nperblock), I0, e1_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp new file mode 100755 index 000000000000..0f2085525fd0 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp @@ -0,0 +1,1329 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp" + +namespace ck { + +template +struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle +{ + static_assert(LoopSched == LoopScheduler::Default, + "Non-default loop scheduler is currently not supported"); + + static constexpr index_t NumD0Tensor = D0sDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + // Gemm0 + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave); + static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave); + + // Gemm1 + static constexpr auto B1K0 = Number{}; + static constexpr auto B1K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + template + __host__ __device__ static constexpr auto + MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + BBlockDesc_BK0_N_BK1{}); + } + + template + 
__host__ __device__ static constexpr auto + MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B1 matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(B1K0, Number{}, B1K1), + make_tuple(Number{} * B1K1, B1K1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned + + SharedMemTrait::b_block_space_size_aligned) * + sizeof(FloatAB); + const index_t gemm1_bytes_end = + (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) * + sizeof(FloatAB); + const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset + + SharedMemTrait::reduction_space_size_aligned) * + sizeof(FloatGemmAcc); + const index_t c_block_bytes_end = + SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle); + + return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, + const C1GridDesc_M_N& c1_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1); + + if(!(M == c1_grid_desc_m_n.GetLength(I0) && Gemm1N == c1_grid_desc_m_n.GetLength(I1))) + { + return false; + } + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 && + Gemm1N % Gemm1NPerBlock == 0)) + { + return false; + } + + // 
check gemm0 gridwise gemm pipeline + const auto num_gemm0_k_loop = K / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop)) + { + return false; + } + + // check gemm1 gridwise gemm pipeline + if(!(NPerBlock % Gemm1KPerBlock == 0)) + { + return false; + } + + const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c1_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeC1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const C1GridDesc_M_N& c1_grid_desc_m_n) + { + const auto M = c1_grid_desc_m_n.GetLength(I0); + const auto N = c1_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / Gemm1NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c1_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const C1GridDesc_M_N& c1_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c1_grid_desc_m_n); + } + + __device__ static auto GetGemm0WaveIdx() + { + const index_t thread_id = get_thread_local_1d_id(); + constexpr auto WaveSize = MfmaSelector::selected_mfma.wave_size; + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(Gemm0MWaves, Gemm0NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto GetGemm0WaveMNIdx(const index_t thread_id) + { + constexpr auto WaveSize = MfmaSelector::selected_mfma.wave_size; + constexpr auto wave_threadid_to_mn_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(WaveSize / MPerXdl, MPerXdl))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return wave_threadid_to_mn_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + static constexpr auto MakeD0sGridPointer() + { + return generate_tuple( + [&](auto i) { + using D0DataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + // D0 desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0GridDesc_M_N& d0_grid_desc_m_n) + { + const auto M = d0_grid_desc_m_n.GetLength(I0); + const auto N = d0_grid_desc_m_n.GetLength(I1); + + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? 
true + : false; + constexpr auto is_scale_mfma = false; + constexpr auto mfma = + MfmaSelector:: + selected_mfma; + constexpr auto N3 = mfma.num_groups_per_blk; + constexpr auto N4 = mfma.num_input_blks; + constexpr auto N5 = mfma.group_size; + return transform_tensor_descriptor( + d0_grid_desc_m_n, + make_tuple(make_unmerge_transform( + make_tuple(M / MPerBlock, MXdlPerWave, Gemm0MWaves, MPerXdl)), + make_unmerge_transform( + make_tuple(N / NPerBlock, NXdlPerWave, Gemm0NWaves, N3, N4, N5))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7, 8, 9>{})); + } + + // D0s desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const D0sGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeGemm0D0GridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + using D0sGridPointer = decltype(MakeD0sGridPointer()); + using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = + remove_cvref_t; + + using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + static constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + static constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + static constexpr auto b1_block_desc_bk0_n_bk1 = + GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1); + + static constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple( + b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b_block_space_offset = a_block_space_size_aligned.value; + static constexpr auto b1_block_space_offset = 0; + + // LDS allocation for reduction + static constexpr index_t reduction_space_size_aligned = + math::integer_least_multiple(BlockSize, max_lds_align); + + static constexpr auto reduction_space_offset = 0; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + static constexpr auto c_block_space_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + }; + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + const FloatAB* __restrict__ p_b1_grid, + FloatC* __restrict__ p_c_grid, + D0sGridPointer p_d0s_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const C0DEElementwiseOperation& c0de_element_op, + const B1ElementwiseOperation& b1_element_op, + const C1DEElementwiseOperation& c1de_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, + const 
C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5& + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + const Block2CTileMap& block_2_ctile_map, + const C0MatrixMask& c0_matrix_mask) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + const auto b1_grid_buf = make_dynamic_buffer( + p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + const auto d0s_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_d0s_grid[i], + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i].GetElementSpaceSize()); + }, + Number{}); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t gemm1_n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // + // set up Gemm0 + // + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + true, // SrcResetCoord + true, // DstResetCoord + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + true, // SrcResetCoord + true, // DstResetCoord + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), // will loop over GemmN dimension + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + // Fused Gemm+Gemm pipeline + // for n in N0: + // for k in K0: + // acc[m][n] += A[m][k] * B0[k][n] + // acc1[m][o] += acc[m][n] * B1[n][o] + + // sanity check + constexpr auto 
lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( + lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_v2< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)), + decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)), + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + true>{}; // TransposeC + + auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a_block_space_offset, + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b_block_space_offset, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + const auto a_block_reset_copy_step = + make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0); + const auto b_block_reset_copy_step = + make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0); + + // gridwise GEMM pipeline + // Only supports LoopScheduler::Default + const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // + // set up Gemm1 + // + + // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type + constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); + constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); + constexpr auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); + constexpr auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); + constexpr auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); + constexpr auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); + constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); + constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); + + constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0); + + // d0 matrix threadwise copy + constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 = + make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId + I1, // NBlockID + m0, // MRepeat + n0, // NRepeat + m1, // MWaveId + n1, // NWaveId + m2, // MPerXdl + n2, // NGroupNum + n3, // NInputNum + n4)); // registerNum + + auto d0s_thread_buf = generate_tuple( + [&](auto i) { + using D0DataType = remove_cvref_t>; + return StaticBuffer< + AddressSpaceEnum::Vgpr, + D0DataType, + d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5.GetElementSpaceSize(), + true>{}; + }, + Number{}); + + const auto wave_id = GetGemm0WaveIdx(); + 
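+            // GetGemm0WaveIdx() above decomposes the flat thread id into
+            // (MWaveId, NWaveId, lane-in-wave) via the merge over
+            // (Gemm0MWaves, Gemm0NWaves, WaveSize); GetGemm0WaveMNIdx() below then splits
+            // that lane into (lane / MPerXdl, lane % MPerXdl). As a small worked example
+            // (assuming WaveSize == 64 and MPerXdl == 32; neither value is fixed by this
+            // comment), lane 37 yields (1, 5): input-block index 1 and intra-xdl row 5,
+            // which seed the NInputIndex and MPerXdl coordinates of the D0 threadwise-copy
+            // origin constructed just below.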
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63 + + auto d0s_threadwise_copy = generate_tuple( + [&](auto i) { + using D0DataType = remove_cvref_t>; + return ThreadwiseTensorSliceTransfer_v2< + D0DataType, + D0DataType, + decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]), + decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5), + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + D0sTransferSrcScalarPerVector, + 1, + false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + make_multi_index(block_work_idx[I0], // MBlockId + 0, // NBlockId + 0, // mrepeat + 0, // nrepeat + wave_id[I0], // MWaveId + wave_id[I1], // NWaveId + wave_m_n_id[I1], // MPerXdl + 0, // group + wave_m_n_id[I0], // NInputIndex + 0)); // register number + }, + Number{}); + // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1 + // n0_n1_n2_n3 -> k0 + // m0_m1_m2 -> m + // n4 -> k1 + // NOTE: had to use merge_v3 or will spit out compilation errors + constexpr auto acc_thread_desc_k0_m_k1 = transform_tensor_descriptor( + acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)), + make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)), + make_pass_through_transform(n4)), + make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // A1 matrix in AccVGPR + // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size + constexpr auto AccN3 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6); + + constexpr auto A1ThreadSlice_K0_M_K1 = + make_tuple(Number{}, Number{}, Number{}); + + constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0]; + constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1]; + constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2]; + constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor( + A1ThreadSlice_K0_M_K1, + make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1)); + + // B1 matrix in LDS memory, dst of blockwise copy + constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A1 matrix blockwise copy + auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< + FloatGemmAcc, + FloatAB, + decltype(acc_thread_desc_k0_m_k1), + decltype(a1_thread_desc_k0_m_k1), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<1, 0, 2>, + 2, + n4>{tensor_operation::element_wise::PassThrough{}}; + + // B1 matrix blockwise copy + auto b1_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b1_grid_desc_bk0_n_bk1), + decltype(b1_block_desc_bk0_n_bk1), + B1BlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + B1BlockTransferSrcVectorDim, + 2, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + 1, + 1, + B1ThreadTransferSrcResetCoordinateAfterRun, + true, // DstResetCoord + NumGemmKPrefetchStage>( + b1_grid_desc_bk0_n_bk1, + make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0), + b1_element_op, + b1_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + auto a1_thread_buf = make_static_buffer( + a1_thread_desc_k0_m_k1.GetElementSpaceSize()); + + // reuse LDS space for gemm0's b_block_buf + auto b1_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, + 
b1_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size + // selected_mfma.k_per_blk <= Gemm1KPack + // + // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common + // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case + // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs + // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will + // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. + // therefore we may just as well assign Gemm1KPack = group_size + + constexpr index_t Gemm1KPack = + MfmaSelector::selected_mfma.group_size; + + auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a1_thread_desc_k0_m_k1), + decltype(b1_block_desc_bk0_n_bk1), + decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)), + decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)), + MPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + Gemm1NXdlPerWave, + Gemm1KPack, + true, // TransposeC + Gemm1KPack, // AMmaKStride + Gemm1KPack * + XdlopsGemm{}.K0PerXdlops>{ + // BMmaKStride + make_tuple(0, 0, 0, 0)}; // A_origin + + auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer(); + + // + // Blockwise softmax + // + auto workspace_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::reduction_space_offset, + SharedMemTrait::reduction_space_size_aligned); + + // get acc0 8D thread cluster + constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() / + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); + constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I0); + constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I1); + constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I2); + constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I3); + constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I4); + constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I5); + constexpr auto tn3 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I6); + constexpr auto tn4 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I7); + + // get acc0 thread map + constexpr auto m0_n_m1_to_m_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(tm0 * tm1, tm2)), + make_pass_through_transform(I1)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + constexpr auto threadid_to_m0_n_m1_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_merge_transform(make_tuple(tm0 * tm1, tn0 * tn1 * tn2 * tn3 * tn4, tm2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + const auto threadid_to_m_n_thread_cluster_adaptor = + chain_tensor_adaptors(m0_n_m1_to_m_n_adaptor, threadid_to_m0_n_m1_adaptor); + + // get acc0 2D thread cluster & 2D thread slice + constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(tm0 * tm1 * tm2, tn0 * tn1 * tn2 * tn3 * tn4)); + constexpr auto thread_slice_desc_m_n = + make_naive_tensor_descriptor_packed(make_tuple(m0 * m1 * m2, n0 * n1 * n2 * n3 * n4)); + + auto blockwise_softmax = BlockwiseSoftmax{}; + + const index_t num_gemm1_k_block_outer_loop = + 
b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock; + constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock; + + // Initialize C + StaticBuffer + c_thread_buf; + c_thread_buf.Clear(); + + // Initialize running sum and max of exponentiating row vectors + using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType; + SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new; + running_sum = 0; + running_sum_new = 0; + running_max = NumericLimits::Lowest(); + running_max_new = NumericLimits::Lowest(); + + // gemm1 K loop + index_t gemm1_k_block_outer_index = 0; + do + { + auto n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock); + if(c0_matrix_mask.IsTileSkippable( + m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock)) + { + continue; + } + // gemm0 + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + acc_thread_buf, + num_k_block_main_loop); + // multiple d + if constexpr(NumD0Tensor) + { + static_assert(NXdlPerWave == n0); + static_assert(MXdlPerWave == m0); + + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + d0s_grid_buf[i], + d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0), + d0s_thread_buf(i)); + }); + static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto) -> auto& { return acc_thread_buf(i); }, + Number<2>{}); + + unpack2(c0de_element_op, dst_data_refs, src_data_refs); + }); + static_for<0, NumD0Tensor, 1>{}([&](auto i) { + d0s_threadwise_copy(i).MoveSrcSliceWindow( + d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i], + make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0)); + }); + } + else + { + static_for<0, acc_thread_buf.Size(), 1>{}( + [&](auto i) { c0de_element_op(acc_thread_buf(i), acc_thread_buf[i]); }); + } + + // do MNK padding or upper triangular masking + if constexpr(MaskOutUpperTriangle || PadN) + { + // 8d thread_desc in thread scope + constexpr auto c_thread_lengths = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); + + // 8d block_desc in block scope + constexpr auto c_block_lengths = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); + + constexpr auto M0 = c_block_lengths[I0]; + constexpr auto N0 = c_block_lengths[I1]; + constexpr auto M1 = c_block_lengths[I2]; + constexpr auto N1 = c_block_lengths[I3]; + constexpr auto M2 = c_block_lengths[I4]; + constexpr auto N2 = c_block_lengths[I5]; + constexpr auto N3 = c_block_lengths[I6]; + constexpr auto N4 = c_block_lengths[I7]; + + // works like multi-dimension static_for (static_ford), but provides both the linear + // index as well as n-d index + using Acc0TileIterator = SpaceFillingCurve< + decltype(c_thread_lengths), + typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type, + typename uniform_sequence_gen::type, + false>; // SnakeCurved + + auto acc0_thread_origin = 
blockwise_gemm.CalculateCThreadOriginDataIndex8D( + Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{}); + + constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)), + make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{})); + + static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) { + auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin; + auto m_local = + block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; + auto n_local = + block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; + auto m_global = m_local + m_block_data_idx_on_grid; + auto n_global = n_local + n_block_data_idx_on_grid; + if(c0_matrix_mask.IsMaskedElement(m_global, n_global)) + { + acc_thread_buf(i) = -ck::NumericLimits::Infinity(); + } + }); + } + + block_sync_lds(); // wait for lds read in gemm0 blockwise gemm + + // softmax + SoftmaxBuf& max = blockwise_softmax.max_value_buf; + SoftmaxBuf& sum = blockwise_softmax.sum_value_buf; + + blockwise_softmax.Run(acc_thread_buf, workspace_buf); + + // TODO: may convert to log domain + running_max_new = mathext::max(max, running_max); + running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + + mathext::exp(max - running_max_new) * sum; + + // gemm1 + { + // TODO: explore using dynamic buffer for a1 thread buffer + // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(), + // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that + // the A1 source buffer is static buffer holding the output of first GEMM and + // requires constexpr offset by design. Therefore, we pass tensor coordinate offset + // explicitly in Run() below. 
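+                // In short, the gemm1 stage below is a peeled software pipeline; a
+                // descriptive sketch of the control flow implemented by the code that
+                // follows:
+                //   preload:  b1 RunRead(gmem) -> RunWrite(LDS)
+                //   for i in [0, num_gemm1_k_block_inner_loop - 1):
+                //       a1 <- constexpr slice i of acc_thread_buf (VGPR to VGPR)
+                //       b1 RunRead of the next tile, overlapped with
+                //       acc1 += a1 * b1 via gemm1_blockwise_gemm.Run()
+                //       b1 RunWrite(LDS) of the prefetched tile
+                //   tail:     copy the last acc slice, final gemm1_blockwise_gemm.Run()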
+ + // Initialize acc1 + acc1_thread_buf.Clear(); + + // preload data into LDS + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, + b1_block_slice_copy_step); + + block_sync_lds(); // wait for reduction LDS read + + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); + + // main body + if constexpr(num_gemm1_k_block_inner_loop > 1) + { + static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) { + a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1, + make_tuple(Number{}, I0, I0), + acc_thread_buf, + a1_thread_desc_k0_m_k1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); + + block_sync_lds(); + + gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + + block_sync_lds(); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, + b1_block_slice_copy_step); + + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); + }); + } + // tail + { + a1_blockwise_copy.Run( + acc_thread_desc_k0_m_k1, + make_tuple( + Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0), + acc_thread_buf, + a1_thread_desc_k0_m_k1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + block_sync_lds(); + + gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + } + } // end gemm1 + + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); + constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); + constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); + constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); + constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); + constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); + constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); + constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); + constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4)); + constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0); + constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1); + + static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { + static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { + auto I = Number{}; + FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V + FloatGemmAcc c = c_thread_buf[I]; // O + FloatGemmAcc c_new = + (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + + math::exp(max[iM] - running_max_new[iM]) * acc1) / + running_sum_new[iM]; // Formula by Dao et al., + // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1 + + c_thread_buf(I) = c_new; // O_new + }); + }); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1, + a_block_reset_copy_step); // rewind K + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1, + b_block_reset_copy_step); // rewind K and step N + + // update before next j iteration + running_max = running_max_new; + running_sum = running_sum_new; + + block_sync_lds(); // wait for gemm1 LDS read + } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + Gemm1NXdlPerWave % 
CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + gemm1_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2)), // M2 = MPerXdl + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + gemm1_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + 
c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), + tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + C1DEElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c1de_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp new file mode 100755 index 000000000000..1754e07e6a2b --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp @@ -0,0 +1,1604 @@ +// 
SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp" + +namespace ck { + +// Gemm0: A [M x K] x B0 [K x L] = Acc [M x L] +// Gemm1: Acc [M x L] x B1 [L x N] = C [M x N] +template +struct GridwiseBatchedGemmSoftmaxGemm_Wmma +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto AK1 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto BK1 = Number{}; + + static constexpr auto L0PerBlock = LTilePerBlock / L1Value; + static constexpr auto AL0 = Number{}; + static constexpr auto AL1 = Number{}; + static constexpr auto BL0 = Number{}; + static constexpr auto BL1 = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = 16; + static constexpr auto WmmaL = 16; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = + remove_cvref_t())>; + + __host__ __device__ static constexpr auto MakeABlockDescriptor() + { + constexpr auto a_block_desc = [&]() { + if constexpr(AEnableLds) + { + // K0->M->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / AK1; + constexpr auto max_lds_align = AK1; + + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, AK1), max_lds_align); + } + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / AK1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + AK1), + make_tuple(Number{} * Number{} * AK1, + Number{} * AK1, + Number{} * AK1, + AK1, + AK1, + AK1, + I1)); + } + }(); + + return a_block_desc; + } + + __host__ __device__ static constexpr auto MakeB0BlockDescriptor() + { + constexpr auto b0_block_desc = [&]() { + if constexpr(B0EnableLds) + { + // K0->L->BK1 Per Block + constexpr auto K0PerBlock = KPerBlock / BK1; + constexpr auto max_lds_align = BK1; + + if constexpr(B0BlockLdsExtraL) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, 
I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, BK1), max_lds_align); + } + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / BK1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + BK1), + make_tuple(Number{} * Number{} * BK1, + Number{} * BK1, + Number{} * BK1, + BK1, + BK1, + BK1, + I1)); + } + }(); + + return b0_block_desc; + } + + __host__ __device__ static constexpr auto MakeB1BlockDescriptor() + { + constexpr auto b1_block_desc = [&]() { + if constexpr(B1EnableLds) + { + // L0->N->BL1 Per Block + constexpr auto max_lds_align = BL1; + + if constexpr(B1BlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, BL1), + make_tuple(Number{} * BL1, BL1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, BL1), max_lds_align); + } + } + else + { + constexpr auto LWmmaPerblock = LPerBlock / WmmaL; + constexpr auto L0PerWmma = WmmaL / 2 / BL1; + // LWmma->NRepeat->MWave->L0PerWmma->LRow->MPerWmma->L1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + BL1), + make_tuple(Number{} * Number{} * BL1, + Number{} * BL1, + Number{} * BL1, + BL1, + BL1, + BL1, + I1)); + } + }(); + + return b1_block_desc; + } + + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() + { + constexpr auto a_block_copy_step = [&]() { + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / AK1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return a_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeB0BlockSliceCopyStep() + { + constexpr auto b0_block_copy_step = [&]() { + if constexpr(B0EnableLds) + { + constexpr auto K0PerBlock = KPerBlock / BK1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b0_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeB1BlockSliceCopyStep() + { + constexpr auto b1_block_copy_step = [&]() { + if constexpr(B1EnableLds) + { + return make_multi_index(L0PerBlock, 0, 0); + } + else + { + constexpr auto LWmmaPerBlock = LTilePerBlock / WmmaL; + + return make_multi_index(LWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b1_block_copy_step; + } + + // Describe how data read from (LDS/VGPR) buffer + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + { + + constexpr auto a_wave_desc = [&]() { + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1 + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); + constexpr auto A_KRow = I1; + return transform_tensor_descriptor( + ABlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // 
KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3); + constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6); + + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return a_wave_desc; + } + + template + __host__ __device__ static constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_&) + { + + constexpr auto b0_wave_desc = [&]() { + if constexpr(B0EnableLds) + { + // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1 + constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else + constexpr auto B_KRow = I1; +#endif + return transform_tensor_descriptor( + B0BlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = B0BlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = B0BlockDesc_{}.GetLength(I3); + constexpr auto B_KRow = B0BlockDesc_{}.GetLength(I4); + constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I6); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return b0_wave_desc; + } + + template + __host__ __device__ static constexpr auto + MakeA1WaveDescriptor_L0_M0_M1_M2_L1(const A1BlockDesc_AL0_M_AL1&) + { + constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0); + constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2); + constexpr auto A_LRow = I1; + return transform_tensor_descriptor( + A1BlockDesc_AL0_M_AL1{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_LRow)), + make_unmerge_transform(make_tuple(Number{}, I1, I1)), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + + template + __host__ __device__ static constexpr auto MakeB1WaveDescriptor(const B1BlockDesc_&) + { + + constexpr auto b1_wave_desc = [&]() { + if constexpr(B1EnableLds) + { + // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1 + constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0); + constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_LRow = I2; +#else + constexpr auto B_LRow = I1; +#endif + return transform_tensor_descriptor( + B1BlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_LRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + constexpr auto LWmma = B1BlockDesc_{}.GetLength(I0); + constexpr auto L0PerWmma = B1BlockDesc_{}.GetLength(I3); + constexpr auto B_LRow = B1BlockDesc_{}.GetLength(I4); + constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I6); + + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + 
Number{})); + } + }(); + + return b1_wave_desc; + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + const index_t gemm0_bytes_end = + (SharedMemTrait::a_block_space_size_aligned * sizeof(ADataType) + + SharedMemTrait::b0_block_space_size_aligned * sizeof(B0DataType)); + + const index_t gemm1_bytes_end = + (SharedMemTrait::b1_block_space_offset + + SharedMemTrait::b1_block_space_size_aligned * sizeof(B1DataType)); + + const index_t softmax_bytes_end = + SharedMemTrait::reduction_space_offset + + SharedMemTrait::reduction_space_size_aligned * sizeof(Acc0DataType); + + const index_t c_block_bytes_end = + SharedMemTrait::c_block_space_size * sizeof(CShuffleDataType); + + return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const B0GridDesc& b0_grid_desc, + const B1GridDesc& b1_grid_desc, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (LPerBlock % (LPerWmma * LRepeat)) == 0, + "Invalid tuning param!"); + + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetB0ProblemsizeLK = [&]() { + if constexpr(B0EnableLds) + { + return make_tuple(b0_grid_desc.GetLength(I1), + b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) * + b0_grid_desc.GetLength(I5), + b0_grid_desc.GetLength(I0) * b0_grid_desc.GetLength(I3) * + b0_grid_desc.GetLength(I4) * b0_grid_desc.GetLength(I6)); + } + }; + + const auto GetB1ProblemsizeNL = [&]() { + if constexpr(B1EnableLds) + { + return make_tuple(b1_grid_desc.GetLength(I1), + b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b1_grid_desc.GetLength(I1) * b1_grid_desc.GetLength(I2) * + b1_grid_desc.GetLength(I5), + b1_grid_desc.GetLength(I0) * b1_grid_desc.GetLength(I3) * + b1_grid_desc.GetLength(I4) * b1_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto L = GetB0ProblemsizeLK()(I0); + const auto K = GetAProblemsizeMK()[I1]; + const auto N = GetB1ProblemsizeNL()(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + { + printf("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n", + M, + N, + 
c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + return false; + } + + if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0)) + { + printf("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | M/L/K/NPerBlock = " + "%d, %d, %d, %d\n", + M, + L, + K, + N, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock); + return false; + } + + // check gemm0 gridwise gemm pipeline + const auto num_gemm0_k_loop = K / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop)) + { + printf("GridwiseOp: outer loop unsupport\n"); + return false; + } + + // check gemm1 gridwise gemm pipeline + if(!(LPerBlock % LTilePerBlock == 0)) + { + printf("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n", + LPerBlock, + LTilePerBlock); + return false; + } + + const auto num_gemm1_k_inner_loop = LPerBlock / LTilePerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop)) + { + printf("GridwiseOp: inner loop unsupport\n"); + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = math::integer_divide_ceil(K, KPerBlock); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DefaultBlock2CTileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), BL1); + + static constexpr auto a_block_space_size_aligned = + AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + static constexpr auto b0_block_space_size_aligned = + B0EnableLds ? math::integer_least_multiple( + MakeB0BlockDescriptor().GetElementSpaceSize(), max_lds_align) + : 0; + static constexpr auto b1_block_space_size_aligned = + B1EnableLds ? 
math::integer_least_multiple( + MakeB1BlockDescriptor().GetElementSpaceSize(), max_lds_align) + : 0; + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b0_block_space_offset = a_block_space_size_aligned; + static constexpr auto b1_block_space_offset = 0; + + // LDS allocation for reduction + // Feature to add, IntraThread Reduction + static constexpr index_t reduction_space_size_aligned = + math::integer_least_multiple(BlockSize, max_lds_align); + + static constexpr auto reduction_space_offset = 0; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_block_space_size = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + .GetElementSpaceSize(); + }; + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const B0DataType* __restrict__ p_b0_grid, + const B1DataType* __restrict__ p_b1_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc& a_grid_desc, + const B0GridDesc& b0_grid_desc, + const B1GridDesc& b1_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const B0ElementwiseOperation& b0_element_op, + const AccElementwiseOperation& acc_element_op, + const B1ElementwiseOperation& b1_element_op, + const CElementwiseOperation& c_element_op, + const C0MatrixMask& c0_matrix_mask, + const Block2CTileMap& block_2_ctile_map) + { + // clang-format off +/*******************************************************************************/ +// Memory buffer zone. + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc.GetElementSpaceSize()); + const auto b0_grid_buf = make_dynamic_buffer( + p_b0_grid, b0_grid_desc.GetElementSpaceSize()); + const auto b1_grid_buf = make_dynamic_buffer( + p_b1_grid, b1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + +/*******************************************************************************/ +// BlockIdx.x -> [BlockId.m, BlockId.n] + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { return; } + + // Store BlockId into SGPR + const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + +/*******************************************************************************/ +// set up Gemm0 +/*******************************************************************************/ + +/*******************************************************************************/ +// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy + constexpr auto a_block_desc = MakeABlockDescriptor(); + constexpr auto b0_block_desc = MakeB0BlockDescriptor(); + + auto a_block_trait = [&](){ + // A matrix blockwise copy + if constexpr(AEnableLds) + { + constexpr auto AK0PerBlock = KPerBlock/ AK1; + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a_block_space_offset, + SharedMemTrait::a_block_space_size_aligned); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1< 
ThisThreadBlock, +/* typename SrcElementwiseOperation, */ AElementwiseOperation, +/* typename DstElementwiseOperation, */ ck::tensor_operation::element_wise::PassThrough, +/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set, +/* typename BlockSliceLengths, */ Sequence, +/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, +/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ ADataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(a_grid_desc), +/* typename DstDesc, */ decltype(a_block_desc), +/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + a_grid_desc, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> MRepeat -> MWaves -> WmmaK/K1 -> MPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/AK1Value; + auto a_block_buf = make_static_buffer( + a_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto a_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + ABlockTransferSrcScalarPerVector, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc, + make_multi_index(0, + m_block_data_idx_on_grid/(MWaves * MPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + }; + + auto b0_block_trait = [&](){ + if constexpr(B0EnableLds) + { + auto b0_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b0_block_space_offset, + SharedMemTrait::b0_block_space_size_aligned); + + auto b0_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + B0BlockTransferThreadClusterLengths_K0_L_K1, + B0BlockTransferThreadClusterArrangeOrder, + B0DataType, + B0DataType, + decltype(b0_grid_desc), + decltype(b0_block_desc), + B0BlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + B0BlockTransferSrcVectorDim, + 2, + B0BlockTransferSrcScalarPerVector, + B0BlockTransferDstScalarPerVector_K1, + 1, + 1, + B0ThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b0_grid_desc, + make_multi_index(0, 0, 0), + b0_element_op, + b0_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b0_block_buf, b0_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> LRepeat -> LWaves -> KRow -> LPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/BK1Value; + auto b0_block_buf = make_static_buffer( + 
b0_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b0_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + B0BlockTransferSrcScalarPerVector, + B0ThreadTransferSrcResetCoordinateAfterRun, + true>( + b0_grid_desc, + make_multi_index(0, + 0/(LWaves * LPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b0_block_buf, b0_blockwise_copy); + } + }; + + auto a_block_buf = a_block_trait()[I0]; + auto a_blockwise_copy = a_block_trait()[I1]; + + auto b0_block_buf = b0_block_trait()[I0]; + auto b0_blockwise_copy = b0_block_trait()[I1]; + +/*******************************************************************************/ + // Gemm0 + constexpr auto KPack = math::integer_least_multiple(math::integer_least_multiple(AK1Value,BK1Value), WmmaK); + + auto blockwise_gemm0 = BlockwiseGemmWMMA< + BlockSize, + ADataType, + B0DataType, + Acc0DataType, + decltype(MakeAWaveDescriptor(a_block_desc)), + decltype(MakeB0WaveDescriptor(b0_block_desc)), + MPerBlock, + LPerBlock, + KPerBlock, + MPerWmma, + LPerWmma, + MRepeat, + LRepeat, + KPack, + AEnableLds, + B0EnableLds, + true>{}; // C' = B' x A' + + + // Prepare Register for A*B0 matrix + auto acc0_thread_buf = blockwise_gemm0.GetCThreadBuffer(); + + constexpr auto acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + + constexpr auto mrepeat = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0); + constexpr auto mwave = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1); + constexpr auto mthreadpersubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2); + constexpr auto lrepeat = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3); + constexpr auto lwave = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4); + constexpr auto lsubgroup = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5); + constexpr auto laccvgprs = acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6); + + constexpr auto acc0_thread_desc_l0perblock_mperblock_l1 = transform_tensor_descriptor( + acc0_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(lrepeat, lwave, lsubgroup)), + make_merge_transform_v3_division_mod(make_tuple(mrepeat, mwave, mthreadpersubgroup)), + make_pass_through_transform(laccvgprs)), + make_tuple(Sequence<3, 4, 5>{}, Sequence<0, 1, 2>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + +/*******************************************************************************/ + // Shift Per SUB_K + constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); + constexpr auto b0_block_slice_copy_step = MakeB0BlockSliceCopyStep(); + + const auto a_block_reset_copy_step = [&](){ + if constexpr(AEnableLds){ + return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0); + } + else{ + return make_multi_index(-a_grid_desc.GetLength(I0), 0, 0, 0, 0, 0, 
0); + } + }(); + + const auto b0_block_reset_copy_step = [&](){ + if constexpr(B0EnableLds){ + return make_multi_index(-b0_grid_desc.GetLength(I0), LPerBlock, 0); + } + else{ + return make_multi_index(-b0_grid_desc.GetLength(I0), LRepeat, 0, 0, 0, 0, 0); + } + }(); + + const auto K = [&](){ + if constexpr(AEnableLds){ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); + } + else{ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6); + } + }(); + + const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock); +/*******************************************************************************/ +// softmax +/*******************************************************************************/ + auto workspace_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::reduction_space_offset, + SharedMemTrait::reduction_space_size_aligned); + // get acc0 7D thread cluster + constexpr auto thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths() / + blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths(); + constexpr auto t_mrepeat = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I0); + constexpr auto t_mwave = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I1); + constexpr auto t_mthreadpersubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I2); + constexpr auto t_lrepeat = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I3); + constexpr auto t_lwave = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I4); + constexpr auto t_lsubgroup = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I5); + constexpr auto t_laccvgprs = thread_cluster_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.At(I6); + // get acc0 thread map + constexpr auto m0_l_m1_to_m_l_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(t_mrepeat * t_mwave, t_mthreadpersubgroup)), + make_pass_through_transform(I1)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + constexpr auto threadid_to_m0_l_m1_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_merge_transform( + make_tuple(t_mrepeat * t_mwave, t_lrepeat * t_lwave * t_lsubgroup * t_laccvgprs, t_mthreadpersubgroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + const auto threadid_to_l_n_thread_cluster_adaptor = + chain_tensor_adaptors(m0_l_m1_to_m_l_adaptor, threadid_to_m0_l_m1_adaptor); + + // get acc0 2D thread cluster & 2D thread slice + constexpr auto thread_cluster_desc_m_l = make_naive_tensor_descriptor_packed( + make_tuple(t_mrepeat * t_mwave * t_mthreadpersubgroup, t_lrepeat * t_lwave * t_lsubgroup * t_laccvgprs)); + + constexpr auto thread_slice_desc_m_l = make_naive_tensor_descriptor_packed( + make_tuple(mrepeat * mwave * mthreadpersubgroup, lrepeat * lwave * lsubgroup * laccvgprs)); + + auto blockwise_softmax = BlockwiseSoftmax{}; + + // Initialize running sum and max of exponentiating row vectors + using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType; + SoftmaxBuf running_sum, 
running_sum_new, running_max, running_max_new; + running_sum = 0; + running_sum_new = 0; + running_max = NumericLimits::Lowest(); + running_max_new = NumericLimits::Lowest(); +/*******************************************************************************/ +// set up Gemm1 +/*******************************************************************************/ + // Acc0 thread buffer -> A1 thread buffer -> blockwise gemm + // A1 matrix in VGPR + constexpr auto A1ThreadSlice_L0PerBlock_MPerBlock_L1 = make_tuple( + Number{}, + Number{}, + Number{}); + + constexpr auto A1ThreadSliceL0PerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I0]; + constexpr auto A1ThreadSliceMPerBlock = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I1]; + constexpr auto A1ThreadSliceL1 = A1ThreadSlice_L0PerBlock_MPerBlock_L1[I2]; + + constexpr auto a1_thread_desc_l0perblock_mperblock_l1 = make_naive_tensor_descriptor( + make_tuple(A1ThreadSliceL0PerBlock, A1ThreadSliceMPerBlock, A1ThreadSliceL1), + make_tuple(A1ThreadSliceMPerBlock * A1ThreadSliceL1, A1ThreadSliceL1, I1)); + + // A1 matrix blockwise copy + auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< + Acc0DataType, + ADataType, + decltype(acc0_thread_desc_l0perblock_mperblock_l1), + decltype(a1_thread_desc_l0perblock_mperblock_l1), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2>, + 2, + laccvgprs>{tensor_operation::element_wise::PassThrough{}}; + + auto a1_thread_buf = make_static_buffer( + a1_thread_desc_l0perblock_mperblock_l1.GetElementSpaceSize()); + + constexpr auto b1_block_desc = MakeB1BlockDescriptor(); + + auto b1_block_trait = [&](){ + if constexpr(B1EnableLds) + { + auto b1_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, + SharedMemTrait::b1_block_space_size_aligned); + + auto b1_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1< ThisThreadBlock, +/* typename SrcElementwiseOperation, */ B1ElementwiseOperation, +/* typename DstElementwiseOperation, */ tensor_operation::element_wise::PassThrough, +/* InMemoryDataOperationEnum DstInMemOp, */ InMemoryDataOperationEnum::Set, +/* typename BlockSliceLengths, */ Sequence, +/* typename ThreadClusterLengths, */ B1BlockTransferThreadClusterLengths_L0_N_L1, +/* typename ThreadClusterArrangeOrder, */ B1BlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ B1DataType, +/* typename DstData, */ B1DataType, +/* typename SrcDesc, */ decltype(b1_grid_desc), +/* typename DstDesc, */ decltype(b1_block_desc), +/* typename SrcDimAccessOrder, */ B1BlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<1, 0, 2>, +/* index_t SrcVectorDim, */ B1BlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ B1BlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ B1BlockTransferDstScalarPerVector_L1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ B1ThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, // DstResetCoord + NumGemmKPrefetchStage>( + b1_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b1_element_op, + b1_block_desc, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b1_block_buf, b1_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 + constexpr auto LWmmaPerBlock = 
LTilePerBlock / WmmaL; + constexpr auto L0PerWmma = WmmaL/2/L1Value; + auto b1_block_buf = make_static_buffer( + b1_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b1_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + B1BlockTransferSrcScalarPerVector, + B1ThreadTransferSrcResetCoordinateAfterRun, + true>( + b1_grid_desc, + make_multi_index(0, + n_block_data_idx_on_grid/(NWaves * NPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b1_block_buf, b1_blockwise_copy); + } + }; + + auto b1_block_buf = b1_block_trait()[I0]; + auto b1_blockwise_copy = b1_block_trait()[I1]; + + constexpr auto b1_block_slice_copy_step = MakeB1BlockSliceCopyStep(); + + auto blockwise_gemm1 = + BlockwiseGemmWMMA{make_tuple(0, 0, 0, 0, 0, 0)}; + + auto acc1_thread_buf = blockwise_gemm1.GetCThreadBuffer(); + + const auto L = [&](){ + if constexpr(B0EnableLds){ + return b0_grid_desc.GetLength(I1); + } + else{ + return b0_grid_desc.GetLength(I1) * b0_grid_desc.GetLength(I2) * b0_grid_desc.GetLength(I5); + } + }(); + + const index_t num_gemm1_l_block_outer_loop = L / LPerBlock; + constexpr index_t num_gemm1_l_block_inner_loop = LPerBlock / LTilePerBlock; + + // Initialize C + StaticBuffer c_thread_buf; + c_thread_buf.Clear(); + +/*******************************************************************************/ + // + // Kernel Main Stage + // + // Flash Attention + // Dao, Tri, et al. "Flashattention: Fast and memory-efficient exact attention with io-awareness." arXiv preprint arXiv:2205.14135 (2022). + index_t gemm1_l_block_outer_index = 0; + // Outer loop, along GEMM_L + // Inner loop, along GEMM_K + do{ + auto l_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(gemm1_l_block_outer_index * LPerBlock); + if(c0_matrix_mask.IsTileSkippable( + m_block_data_idx_on_grid, l_block_data_idx_on_grid, MPerBlock, LPerBlock)) + { + continue; + } + // gemm0 start, A-B swaped + GridwiseGemmPipe::template Run(a_grid_desc, + a_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b0_grid_desc, + b0_block_desc, + b0_blockwise_copy, + b0_grid_buf, + b0_block_buf, + b0_block_slice_copy_step, + blockwise_gemm0, + acc0_thread_buf, + KBlockMainLoop); + // do MNK padding or upper triangular masking + if constexpr(MaskOutUpperTriangle || PadN) + { + // 7d thread_desc in thread scope + constexpr auto c_thread_lengths = + blockwise_gemm0.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths(); + + // 7d block_desc in block scope + constexpr auto c_block_lengths = + blockwise_gemm0.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs().GetLengths(); + + constexpr auto MREPEAT = c_block_lengths[I0]; + constexpr auto MWAVE = c_block_lengths[I1]; + constexpr auto MTHREADSubGroup = c_block_lengths[I2]; + constexpr auto LREPEAT = c_block_lengths[I3]; + constexpr auto LWAVE = c_block_lengths[I4]; + constexpr auto LSUBGROUP = c_block_lengths[I5]; + constexpr auto LACCVGPRS = c_block_lengths[I6]; + + // works like multi-dimension static_for (static_ford), but provides both the linear + // index as well as n-d index + using Acc0TileIterator = SpaceFillingCurve< + decltype(c_thread_lengths), + typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type, + typename 
uniform_sequence_gen::type, + false>; // SnakeCurved + + auto acc0_thread_origin = blockwise_gemm0.CalculateCThreadOriginDataIndex7D( + Number<0>{}, Number<0>{}); + + constexpr auto block_idx_to_m_l_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MREPEAT, MWAVE, MTHREADSubGroup)), + make_unmerge_transform(make_tuple(LREPEAT, LWAVE, LSUBGROUP, LACCVGPRS))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{})); + + static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) { + auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin; + auto m_local = block_idx_to_m_l_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; + auto l_local = block_idx_to_m_l_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; + auto m_global = m_local + m_block_data_idx_on_grid; + auto l_global = l_local + l_block_data_idx_on_grid; + if(c0_matrix_mask.IsMaskedElement(m_global, l_global)) + { + acc0_thread_buf(i) = -ck::NumericLimits::Infinity(); + } + else + { + acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); + } + }); + } + else + { static_for<0, acc0_thread_buf.Size(), 1>{}( + [&](auto i) { acc_element_op(acc0_thread_buf(i), acc0_thread_buf[i]); }); + } + + block_sync_lds(); + // Tiled softmax start + // softmax + SoftmaxBuf& max = blockwise_softmax.max_value_buf; + SoftmaxBuf& sum = blockwise_softmax.sum_value_buf; + + blockwise_softmax.Run(acc0_thread_buf, workspace_buf); + + // TODO: may convert to log domain + running_max_new = mathext::max(max, running_max); + running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + + mathext::exp(max - running_max_new) * sum; + + // gemm1 + { + // TODO: explore using dynamic buffer for a1 thread buffer + // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(), + // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that + // the A1 source buffer is static buffer holding the output of first GEMM and + // requires constexpr offset by design. Therefore, we pass tensor coordinate offset + // explicitly in Run() below. 
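The running_max_new / running_sum_new update a few lines above is the streaming ("online") softmax recurrence: each L-tile contributes only its local row maximum and its local sum of exponentials, and the running statistics are rescaled so earlier tiles never need to be revisited. Below is a minimal host-side sketch of the same recurrence; the struct and variable names are illustrative and are not part of this kernel.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Running (streaming) softmax statistics for one row, mirroring the
// running_max / running_sum update performed above; names are illustrative.
struct RunningSoftmax
{
    float running_max = -INFINITY;
    float running_sum = 0.f;

    // Fold in one tile's local max and its local sum of exp(x - local_max).
    void update(float tile_max, float tile_sum)
    {
        const float new_max = std::max(tile_max, running_max);
        running_sum = std::exp(running_max - new_max) * running_sum +
                      std::exp(tile_max - new_max) * tile_sum;
        running_max = new_max;
    }
};

int main()
{
    // Two tiles of one logical row; after both updates running_sum equals the
    // single-pass softmax denominator computed against the global max.
    const std::vector<std::vector<float>> tiles{{1.f, 3.f, 2.f}, {5.f, 0.f, 4.f}};

    RunningSoftmax r;
    for(const auto& t : tiles)
    {
        float m = -INFINITY, sum = 0.f;
        for(float x : t) m = std::max(m, x);
        for(float x : t) sum += std::exp(x - m);
        r.update(m, sum);
    }
    std::printf("row max = %f, softmax denominator = %f\n", r.running_max, r.running_sum);
    return 0;
}
```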
+ + // Initialize acc1 + acc1_thread_buf.Clear(); + + // preload data into LDS + b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc, + b1_block_slice_copy_step); + + block_sync_lds(); // wait for reduction LDS read + + b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf); + + // main body + if constexpr(num_gemm1_l_block_inner_loop > 1) + { + static_for<0, num_gemm1_l_block_inner_loop - 1, 1>{}([&](auto i) { + // Data cast from Acc0DataType to ADataType happen here + a1_blockwise_copy.Run(acc0_thread_desc_l0perblock_mperblock_l1, + make_tuple(Number{}, I0, I0), + acc0_thread_buf, + a1_thread_desc_l0perblock_mperblock_l1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + b1_blockwise_copy.RunRead(b1_grid_desc, b1_grid_buf); + + block_sync_lds(); + + blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + + block_sync_lds(); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc, + b1_block_slice_copy_step); + + b1_blockwise_copy.RunWrite(b1_block_desc, b1_block_buf); + }); + } + // tail + { + a1_blockwise_copy.Run( + acc0_thread_desc_l0perblock_mperblock_l1, + make_tuple( + Number<(num_gemm1_l_block_inner_loop - 1) * A1ThreadSliceL0PerBlock>{}, I0, I0), + acc0_thread_buf, + a1_thread_desc_l0perblock_mperblock_l1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + block_sync_lds(); + + blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + } + } // end gemm1 + + constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + constexpr auto c_mrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I0); + constexpr auto c_mwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I1); + constexpr auto c_mthreadpersubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I2); + constexpr auto c_nrepeat = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I3); + constexpr auto c_nwave = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I4); + constexpr auto c_nsubgroup = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I5); + constexpr auto c_naccvgprs = c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs.GetLength(I6); + + constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(c_mrepeat * c_mwave * c_mthreadpersubgroup, + c_nrepeat * c_nwave * c_nsubgroup * c_naccvgprs)); + constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0); + constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1); + + static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { + static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { + auto I = Number{}; + Acc1DataType acc1 = acc1_thread_buf[I]; // P*V + Acc1DataType c = c_thread_buf[I]; // O + Acc1DataType c_new = + (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + + math::exp(max[iM] - running_max_new[iM]) * acc1) / + running_sum_new[iM]; + + c_thread_buf(I) = c_new; // O_new + }); + }); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, + a_block_reset_copy_step); // rewind K + b0_blockwise_copy.MoveSrcSliceWindow(b0_grid_desc, + b0_block_reset_copy_step); // rewind K and 
step N + + // update before next j iteration + running_max = running_max_new; + running_sum = running_sum_new; + + block_sync_lds(); // wait for gemm1 LDS read + }while(++gemm1_l_block_outer_index < num_gemm1_l_block_outer_loop); +/*******************************************************************************/ + // write out to C, implement shuffle + { + constexpr auto c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = + blockwise_gemm1.GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + + // This API Provide All dimension (size) you need + constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp = + blockwise_gemm1.GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs(); + + constexpr auto MWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I1); + constexpr auto MThreadPerSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I2); + constexpr auto NWave = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I4); + constexpr auto NSubGroup = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I5); + constexpr auto NAccVgprs = c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs_tmp.GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize()); + + constexpr auto c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs = transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MThreadPerSubGroup // MThreadPerSubGroup = MPerWmma + )), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NSubGroup, + NAccVgprs))), // NSubGroup * NAccVgprs = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1, 2>{}, Sequence<>{}, Sequence<3, 4, 5, 6>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = blockwise_gemm1.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NSubGroup, NAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto 
m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_mthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nsubgroup_naccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 8, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_mthreadpersubgroup_nrepeat_nwave_nsubgroup_naccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if 
constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + // clang-format on + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp new file mode 100755 index 000000000000..33b9199ea553 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp @@ -0,0 +1,1142 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_softmax.hpp" + +namespace ck { + +/** + * @brief Gridwise gemm + softmax + gemm fusion + * + */ +template +struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle +{ + static_assert(LoopSched == LoopScheduler::Default, + "Non-default loop scheduler is currently not supported"); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + // Gemm0 + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave); + static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave); + + // Gemm1 + static constexpr auto B1K0 = Number{}; + static constexpr auto B1K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + template + __host__ __device__ static constexpr auto + MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static constexpr auto + 
MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K( + BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B1 matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(B1K0, Number{}, B1K1), + make_tuple(Number{} * B1K1, B1K1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned + + SharedMemTrait::b_block_space_size_aligned) * + sizeof(FloatAB); + const index_t gemm1_bytes_end = + (SharedMemTrait::b1_block_space_offset + SharedMemTrait::b1_block_space_size_aligned) * + sizeof(FloatAB); + const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset + + SharedMemTrait::reduction_space_size_aligned) * + sizeof(FloatGemmAcc); + const index_t c_block_bytes_end = + SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle); + + return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1))) + { + return false; + } + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 && + Gemm1N % Gemm1NPerBlock == 0)) + { + return false; + } + + // check gemm0 gridwise gemm pipeline + const auto 
num_gemm0_k_loop = K / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop)) + { + return false; + } + + // check gemm1 gridwise gemm pipeline + if(!(NPerBlock % Gemm1KPerBlock == 0)) + { + return false; + } + + const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / Gemm1NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + static constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + static constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + static constexpr auto b1_block_desc_bk0_n_bk1 = + GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1); + + static constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + static constexpr auto b1_block_space_size_aligned = math::integer_least_multiple( + b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b_block_space_offset = a_block_space_size_aligned.value; + static constexpr auto b1_block_space_offset = 0; + + // LDS allocation for reduction + static constexpr index_t reduction_space_size_aligned = + math::integer_least_multiple(BlockSize, max_lds_align); + + static constexpr auto reduction_space_offset = 0; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + static constexpr auto c_block_space_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + }; + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + 
const FloatAB* __restrict__ p_b_grid, + const FloatAB* __restrict__ p_b1_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const AccElementwiseOperation& acc_element_op, + const B1ElementwiseOperation& b1_element_op, + const CElementwiseOperation& c_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2CTileMap& block_2_ctile_map, + const C0MatrixMask& c0_matrix_mask) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + const auto b1_grid_buf = make_dynamic_buffer( + p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t gemm1_n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // + // set up Gemm0 + // + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + true, // SrcResetCoord + true, // DstResetCoord + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + true, // SrcResetCoord + true, // DstResetCoord + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), // will loop over GemmN dimension + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 
0, 0), + tensor_operation::element_wise::PassThrough{}); + + // Fused Gemm+Gemm pipeline + // for n in N0: + // for k in K0: + // acc[m][n] += A[m][k] * B0[k][n] + // acc1[m][o] += acc[m][n] * B1[n][o] + + // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( + lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_v2< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)), + decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)), + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + true>{}; // TransposeC + + auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a_block_space_offset, + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b_block_space_offset, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + const auto a_block_reset_copy_step = + make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0); + const auto b_block_reset_copy_step = + make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0); + + // gridwise GEMM pipeline + // Only supports LoopScheduler::Default + const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // + // set up Gemm1 + // + + // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type + constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); + constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); + constexpr auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); + constexpr auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); + constexpr auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); + constexpr auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); + constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); + constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); + + constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0); + + // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1 + // n0_n1_n2_n3 -> k0 + // m0_m1_m2 -> m + // n4 -> k1 + // NOTE: had to use merge_v3 or will spit out compilation errors + constexpr auto acc_thread_desc_k0_m_k1 = transform_tensor_descriptor( + acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)), + 
make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)), + make_pass_through_transform(n4)), + make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + // A1 matrix in AccVGPR + // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size + constexpr auto AccN3 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6); + + constexpr auto A1ThreadSlice_K0_M_K1 = + make_tuple(Number{}, Number{}, Number{}); + + constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0]; + constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1]; + constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2]; + constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor( + A1ThreadSlice_K0_M_K1, + make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1)); + + // B1 matrix in LDS memory, dst of blockwise copy + constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A1 matrix blockwise copy + auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic< + FloatGemmAcc, + FloatAB, + decltype(acc_thread_desc_k0_m_k1), + decltype(a1_thread_desc_k0_m_k1), + tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<1, 0, 2>, + 2, + n4>{tensor_operation::element_wise::PassThrough{}}; + + // B1 matrix blockwise copy + auto b1_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + B1BlockTransferThreadClusterLengths_BK0_N_BK1, + B1BlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b1_grid_desc_bk0_n_bk1), + decltype(b1_block_desc_bk0_n_bk1), + B1BlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + B1BlockTransferSrcVectorDim, + 2, + B1BlockTransferSrcScalarPerVector, + B1BlockTransferDstScalarPerVector_BK1, + 1, + 1, + B1ThreadTransferSrcResetCoordinateAfterRun, + true, // DstResetCoord + NumGemmKPrefetchStage>( + b1_grid_desc_bk0_n_bk1, + make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0), + b1_element_op, + b1_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + tensor_operation::element_wise::PassThrough{}); + + auto a1_thread_buf = make_static_buffer( + a1_thread_desc_k0_m_k1.GetElementSpaceSize()); + + // reuse LDS space for gemm0's b_block_buf + auto b1_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b1_block_space_offset, + b1_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size + // selected_mfma.k_per_blk <= Gemm1KPack + // + // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common + // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case + // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs + // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will + // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7]. 
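To make the mismatch described above concrete, here is a small host-side sketch; group_size and num_input_blks are hypothetical values chosen only to reproduce the a1[[0:3, 8:11]] example from the comment, and the real values come from MfmaSelector.

```cpp
#include <cstdio>

int main()
{
    // Hypothetical xdlops accumulator layout parameters, chosen only to
    // reproduce the a1[[0:3, 8:11]] example in the comment above.
    constexpr int group_size     = 4; // contiguous K elements one lane holds in VGPRs
    constexpr int num_input_blks = 2; // interleaving factor between groups along K

    std::printf("a1 K indices held by one lane (first two groups): ");
    for(int g = 0; g < 2; ++g)
        for(int k = 0; k < group_size; ++k)
            std::printf("%d ", g * group_size * num_input_blks + k);
    std::printf("\n"); // prints: 0 1 2 3 8 9 10 11

    // Packing Gemm1KPack = 8 would pair these non-contiguous a1 elements with
    // b1 K indices 0..7, i.e. the summation indices no longer line up, which
    // is why Gemm1KPack is capped at group_size.
    return 0;
}
```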
+ // therefore we may just as well assign Gemm1KPack = group_size + + constexpr index_t Gemm1KPack = + MfmaSelector::selected_mfma.group_size; + + auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2< + BlockSize, + FloatAB, + FloatGemmAcc, + decltype(a1_thread_desc_k0_m_k1), + decltype(b1_block_desc_bk0_n_bk1), + decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)), + decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)), + MPerBlock, + Gemm1NPerBlock, + Gemm1KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + Gemm1NXdlPerWave, + Gemm1KPack, + true, // TransposeC + Gemm1KPack, // AMmaKStride + Gemm1KPack * + XdlopsGemm{}.K0PerXdlops>{ + // BMmaKStride + make_tuple(0, 0, 0, 0)}; // A_origin + + auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer(); + + // + // Blockwise softmax + // + auto workspace_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::reduction_space_offset, + SharedMemTrait::reduction_space_size_aligned); + + // get acc0 8D thread cluster + constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() / + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); + constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I0); + constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I1); + constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I2); + constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I3); + constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I4); + constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I5); + constexpr auto tn3 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I6); + constexpr auto tn4 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I7); + + // get acc0 thread map + constexpr auto m0_n_m1_to_m_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(tm0 * tm1, tm2)), + make_pass_through_transform(I1)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + constexpr auto threadid_to_m0_n_m1_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_merge_transform(make_tuple(tm0 * tm1, tn0 * tn1 * tn2 * tn3 * tn4, tm2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + const auto threadid_to_m_n_thread_cluster_adaptor = + chain_tensor_adaptors(m0_n_m1_to_m_n_adaptor, threadid_to_m0_n_m1_adaptor); + + // get acc0 2D thread cluster & 2D thread slice + constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(tm0 * tm1 * tm2, tn0 * tn1 * tn2 * tn3 * tn4)); + constexpr auto thread_slice_desc_m_n = + make_naive_tensor_descriptor_packed(make_tuple(m0 * m1 * m2, n0 * n1 * n2 * n3 * n4)); + + auto blockwise_softmax = BlockwiseSoftmax{}; + + const index_t num_gemm1_k_block_outer_loop = + b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock; + constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock; + + // Initialize C + StaticBuffer + c_thread_buf; + c_thread_buf.Clear(); + + // Initialize running sum and max of exponentiating row vectors + using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType; + SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new; + running_sum = 0; + running_sum_new = 0; + running_max = NumericLimits::Lowest(); + running_max_new = NumericLimits::Lowest(); + + // gemm1 K loop + index_t gemm1_k_block_outer_index = 0; + do + { + auto n_block_data_idx_on_grid = + 
__builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock); + if(c0_matrix_mask.IsTileSkippable( + m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock)) + { + continue; + } + // gemm0 + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + acc_thread_buf, + num_k_block_main_loop); + + // do MNK padding or upper triangular masking + if constexpr(MaskOutUpperTriangle || PadN) + { + // 8d thread_desc in thread scope + constexpr auto c_thread_lengths = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); + + // 8d block_desc in block scope + constexpr auto c_block_lengths = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths(); + + constexpr auto M0 = c_block_lengths[I0]; + constexpr auto N0 = c_block_lengths[I1]; + constexpr auto M1 = c_block_lengths[I2]; + constexpr auto N1 = c_block_lengths[I3]; + constexpr auto M2 = c_block_lengths[I4]; + constexpr auto N2 = c_block_lengths[I5]; + constexpr auto N3 = c_block_lengths[I6]; + constexpr auto N4 = c_block_lengths[I7]; + + // works like multi-dimension static_for (static_ford), but provides both the linear + // index as well as n-d index + using Acc0TileIterator = SpaceFillingCurve< + decltype(c_thread_lengths), + typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type, + typename uniform_sequence_gen::type, + false>; // SnakeCurved + + auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D( + Number<0>{}, Number<0>{}, Number<0>{}, Number<0>{}); + + constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(M0, M1, M2)), + make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5, 6, 7>{})); + + static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) { + auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin; + auto m_local = + block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0]; + auto n_local = + block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1]; + auto m_global = m_local + m_block_data_idx_on_grid; + auto n_global = n_local + n_block_data_idx_on_grid; + if(c0_matrix_mask.IsMaskedElement(m_global, n_global)) + { + acc_thread_buf(i) = -ck::NumericLimits::Infinity(); + } + else + { + acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); + } + }); + } + else + { + static_for<0, acc_thread_buf.Size(), 1>{}( + [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); }); + } + + block_sync_lds(); // wait for lds read in gemm0 blockwise gemm + + // softmax + SoftmaxBuf& max = blockwise_softmax.max_value_buf; + SoftmaxBuf& sum = blockwise_softmax.sum_value_buf; + + blockwise_softmax.Run(acc_thread_buf, workspace_buf); + + // TODO: may convert to log domain + running_max_new = mathext::max(max, running_max); + running_sum_new = mathext::exp(running_max - running_max_new) * running_sum + + mathext::exp(max - running_max_new) * sum; + + // gemm1 + { + // TODO: explore using dynamic buffer for a1 thread buffer + // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(), + // RunWrite(), and MoveSliceWindow(). 
But it is impossible to implement given that + // the A1 source buffer is static buffer holding the output of first GEMM and + // requires constexpr offset by design. Therefore, we pass tensor coordinate offset + // explicitly in Run() below. + + // Initialize acc1 + acc1_thread_buf.Clear(); + + // preload data into LDS + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, + b1_block_slice_copy_step); + + block_sync_lds(); // wait for reduction LDS read + + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); + + // main body + if constexpr(num_gemm1_k_block_inner_loop > 1) + { + static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) { + a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1, + make_tuple(Number{}, I0, I0), + acc_thread_buf, + a1_thread_desc_k0_m_k1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf); + + block_sync_lds(); + + gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + + block_sync_lds(); + + b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1, + b1_block_slice_copy_step); + + b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf); + }); + } + // tail + { + a1_blockwise_copy.Run( + acc_thread_desc_k0_m_k1, + make_tuple( + Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0), + acc_thread_buf, + a1_thread_desc_k0_m_k1, + make_tuple(I0, I0, I0), + a1_thread_buf); + + block_sync_lds(); + + gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf); + } + } // end gemm1 + + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0); + constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1); + constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2); + constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3); + constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4); + constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5); + constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6); + constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7); + constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed( + make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4)); + constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0); + constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1); + + static_for<0, c_thread_buf_slice_m, 1>{}([&](auto iM) { + static_for<0, c_thread_buf_slice_n, 1>{}([&](auto iN) { + auto I = Number{}; + FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V + FloatGemmAcc c = c_thread_buf[I]; // O + FloatGemmAcc c_new = + (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c + + math::exp(max[iM] - running_max_new[iM]) * acc1) / + running_sum_new[iM]; // Formula by Dao et al., + // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1 + + c_thread_buf(I) = c_new; // O_new + }); + }); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1, + a_block_reset_copy_step); // rewind K + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1, + b_block_reset_copy_step); // rewind K and step N + + // update before next j iteration + running_max = running_max_new; + running_sum = running_sum_new; + + block_sync_lds(); 
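The c_new expression above is the output-correction step from Dao et al. (section 3.1): the already-normalized partial output is rescaled by the old denominator and by the shift in the running maximum, the current tile's unnormalized P*V contribution is added, and the result is renormalized with the updated denominator. The following standalone scalar sketch (hypothetical names, a single value column, two score tiles) applies the same recurrence and reproduces the single-pass softmax(s) · v result:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    // One query row: attention scores s[l] and a single value column v[l],
    // processed in two tiles of size 2 (analogue of the NPerBlock tiles).
    const std::vector<float> s{1.f, 3.f, 0.5f, 2.f};
    const std::vector<float> v{10.f, 20.f, 30.f, 40.f};
    const int n    = static_cast<int>(s.size());
    const int tile = 2;

    float running_max = -INFINITY, running_sum = 0.f, o = 0.f; // o plays the role of c_thread_buf
    for(int begin = 0; begin < n; begin += tile)
    {
        // per-tile statistics: local max, local sum, unnormalized P*V (acc1 analogue)
        float m = -INFINITY, sum = 0.f, acc1 = 0.f;
        for(int l = begin; l < begin + tile; ++l) m = std::max(m, s[l]);
        for(int l = begin; l < begin + tile; ++l)
        {
            const float p = std::exp(s[l] - m);
            sum  += p;
            acc1 += p * v[l];
        }

        // same correction as the c_new formula above
        const float new_max = std::max(m, running_max);
        const float new_sum = std::exp(running_max - new_max) * running_sum +
                              std::exp(m - new_max) * sum;
        o = (running_sum * std::exp(running_max - new_max) * o +
             std::exp(m - new_max) * acc1) / new_sum;
        running_max = new_max;
        running_sum = new_sum;
    }

    // single-pass reference: softmax(s) . v
    const float m = *std::max_element(s.begin(), s.end());
    float denom = 0.f, ref = 0.f;
    for(int l = 0; l < n; ++l) denom += std::exp(s[l] - m);
    for(int l = 0; l < n; ++l) ref += std::exp(s[l] - m) / denom * v[l];

    std::printf("tiled = %f  reference = %f\n", o, ref);
    return 0;
}
```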
// wait for gemm1 LDS read + } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + gemm1_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2)), // M2 = MPerXdl + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + gemm1_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: 
threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), + tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp new file mode 100755 
index 000000000000..ed1ffdd85765 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp @@ -0,0 +1,554 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_batchnorm_backward_with_blockwise_welford( + const XYGridDesc_M_K x_grid_desc_m_k, + const XYGridDesc_M_K dy_grid_desc_m_k, + const XYGridDesc_M_K dx_grid_desc_m_k, + const ScaleBiasGridDesc_M scale_grid_desc_m, + const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m, + const MeanVarGridDesc_M mean_var_grid_desc_m, + const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, + long_index_t reduce_size, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x, + const DyDataType* const __restrict__ p_dy, + const ScaleDataType* const __restrict__ p_scale, + bool haveSavedMeanInvVar, + const MeanVarDataType* const __restrict__ p_savedMean, + const MeanVarDataType* const __restrict__ p_savedInvVar, + const DyElementwiseOp dy_elementwise_op, + DxDataType* const __restrict__ p_dx, + DscaleDbiasDataType* const __restrict__ p_dscale, + DscaleDbiasDataType* const __restrict__ p_dbias) +{ + GridwiseBatchrNormBackwardWithBlockwiseWelford_::Run(x_grid_desc_m_k, + dy_grid_desc_m_k, + dx_grid_desc_m_k, + scale_grid_desc_m, + dscale_dbias_grid_desc_m, + mean_var_grid_desc_m, + get_reduce_count_per_thread, + reduce_size, + num_k_block_tile_iteration, + epsilon, + p_x, + p_dy, + p_scale, + haveSavedMeanInvVar, + p_savedMean, + p_savedInvVar, + dy_elementwise_op, + p_dx, + p_dscale, + p_dbias); +}; + +template +struct GridwiseBatchNormBackwardWithBlockwiseWelford +{ + static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 && + MThreadSliceSize % DySrcVectorSize == 0 && + MThreadSliceSize % DxDstVectorSize == 0) || + (XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 && + KThreadSliceSize % DySrcVectorSize == 0 && + KThreadSliceSize % DxDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XDyDxVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford; + + using BlockwiseReduce = PartitionedBlockwiseReduction; 
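// The ThreadwiseWelford/BlockwiseWelford aliases above split the mean/variance
// computation in the usual two-level way: each thread folds its samples into a
// partial (count, mean, M2) state, then the blockwise step merges partials across
// the thread cluster. A minimal scalar sketch of both steps (not the CK
// implementation; the merge is the standard Chan et al. pairwise combination):
//
//   struct WelfordState { int count = 0; float mean = 0.f; float m2 = 0.f; };
//
//   // threadwise step: fold in one sample
//   void welford_push(WelfordState& s, float x)
//   {
//       ++s.count;
//       float delta = x - s.mean;
//       s.mean += delta / s.count;
//       s.m2   += delta * (x - s.mean);   // biased variance = m2 / count
//   }
//
//   // blockwise step: merge two partial states (what the LDS reduction does per pair)
//   WelfordState welford_merge(const WelfordState& a, const WelfordState& b)
//   {
//       WelfordState r;
//       r.count = a.count + b.count;
//       if(r.count == 0) return r;
//       float delta = b.mean - a.mean;
//       r.mean = a.mean + delta * b.count / r.count;
//       r.m2   = a.m2 + b.m2 + delta * delta * a.count * b.count / r.count;
//       return r;
//   }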
+ + using ThreadwiseReduce = ThreadwiseReduction; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + // clang-format off + // Blockwise BatchNorm Backward + // Input: x, dy, scale, savedMean and savedInvVar (optional), reduce_size + // Output: dx, dscale, dbias + // Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance) + // Step 2: reduction: dbias = sum(dy), dscale = sum(dy *(x-mean) * inv-variance) + // Step 3: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance)) elementwise-ly + // clang-format on + __device__ static void Run(const XYGridDesc_M_K x_grid_desc_m_k, + const XYGridDesc_M_K dy_grid_desc_m_k, + const XYGridDesc_M_K dx_grid_desc_m_k, + const ScaleBiasGridDesc_M scale_grid_desc_m, + const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m, + const MeanVarGridDesc_M mean_var_grid_desc_m, + const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, + long_index_t reduce_size, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x, + const DyDataType* const __restrict__ p_dy, + const ScaleDataType* const __restrict__ p_scale, + bool haveSavedMeanInvVar, + const MeanVarDataType* const __restrict__ p_savedMean, + const MeanVarDataType* const __restrict__ p_savedInvVar, + const DyElementwiseOp dy_elementwise_op, + DxDataType* const __restrict__ p_dx, + DscaleDbiasDataType* const __restrict__ p_dscale, + DscaleDbiasDataType* const __restrict__ p_dbias) + { + using ck::math::sqrt; + + __shared__ AccDataType p_reduce_work_buffer[BlockSize]; + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + StaticBuffer + x_thread_buf; + + StaticBuffer + dy_thread_buf; + + StaticBuffer + dx_thread_buf; + + // buffer of values of dy * (x-mean) * invVariance, used as input of Blockwise reduction + StaticBuffer + tmp1_thread_buf; + + StaticBuffer scale_thread_buf; + + StaticBuffer mean_thread_buf; + StaticBuffer var_thread_buf; + StaticBuffer& + inv_var_thread_buf = var_thread_buf; + + StaticBuffer dscale_thread_buf; + StaticBuffer dbias_thread_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2( + dy_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, 
+ thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dx_store = + ThreadwiseTensorSliceTransfer_v1r3( + dx_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize), + PassThroughOp{}); + + auto threadwise_scale_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + ScaleSrcVectorSize, + 1, + true>( + scale_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + auto threadwise_dscale_dbias_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + DscaleDbiasDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + dscale_dbias_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize); + + const auto x_global_buf = make_dynamic_buffer( + p_x, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto dy_global_buf = make_dynamic_buffer( + p_dy, dy_grid_desc_m_k.GetElementSpaceSize()); + + auto dx_global_buf = make_dynamic_buffer( + p_dx, dx_grid_desc_m_k.GetElementSpaceSize()); + + const auto scale_global_buf = make_dynamic_buffer( + p_scale, scale_grid_desc_m.GetElementSpaceSize()); + + auto dscale_global_buf = make_dynamic_buffer( + p_dscale, dscale_dbias_grid_desc_m.GetElementSpaceSize()); + + auto dbias_global_buf = make_dynamic_buffer( + p_dbias, dscale_dbias_grid_desc_m.GetElementSpaceSize()); + + // clang-format off + // Step 1: calculating mean and inv-variance using welford method (if savedMean/savedInvVar not available), where inv-variance = 1/sqrt(epsilon+variance) + // clang-format on + + if(haveSavedMeanInvVar) + { + const auto mean_global_buf = make_dynamic_buffer( + p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + const auto inv_var_global_buf = make_dynamic_buffer( + p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto threadwise_mean_inv_var_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + MeanVarSrcVectorSize, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m, + mean_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf); + + threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m, + inv_var_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + inv_var_thread_buf); + } + else + { + auto threadwise_welford = ThreadwiseWelford(); + threadwise_welford.max_count_ = get_reduce_count_per_thread(thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_welford.Run(x_thread_buf, mean_thread_buf, var_thread_buf); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + }); + + // calculate inv-variance as 
1/sqrt(epsilon+variance) + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + inv_var_thread_buf(I) = + type_convert(1.0) / sqrt(var_thread_buf[I] + epsilon); + }); + + threadwise_x_load.SetSrcSliceOrigin( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + }; + + // clang-format off + // Step 2: reduction: dbias = sum(dy), dscale = sum(dy *(x-mean) * inv-variance) + // clang-format on + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + dscale_thread_buf(I) = type_convert(0); + dbias_thread_buf(I) = type_convert(0); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_dy_load.Run(dx_grid_desc_m_k, + dy_global_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + dy_elementwise_op(dy_thread_buf(Number{}), + dy_thread_buf[Number{}]); + + AccDataType norm_x = (x_thread_buf[Number{}] - mean_thread_buf[iM]) * + inv_var_thread_buf[iM]; + + tmp1_thread_buf(Number{}) = norm_x * dy_thread_buf[Number{}]; + }); + }); + + ThreadwiseReduce::Reduce(tmp1_thread_buf, dscale_thread_buf); + ThreadwiseReduce::Reduce(dy_thread_buf, dbias_thread_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k); + }; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + BlockwiseReduce::Reduce(reduce_work_buf, dscale_thread_buf(I)); + block_sync_lds(); + BlockwiseReduce::Reduce(reduce_work_buf, dbias_thread_buf(I)); + }); + + if(thread_k_cluster_id == 0) + { + threadwise_dscale_dbias_store.Run(thread_buffer_desc_m, + make_tuple(I0), + dscale_thread_buf, + dscale_dbias_grid_desc_m, + dscale_global_buf); + + threadwise_dscale_dbias_store.Run(thread_buffer_desc_m, + make_tuple(I0), + dbias_thread_buf, + dscale_dbias_grid_desc_m, + dbias_global_buf); + }; + + // clang-format off + // Step 3: calculating dx = 1/reduce_size * inv-variance * scale * (reduce_size * dy - dbias - dscale * (x - mean) * inv-variance)) elementwise-ly + // clang-format on + + threadwise_scale_load.Run(scale_grid_desc_m, + scale_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + scale_thread_buf); + + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k; + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k); + + AccDataType inv_reduce_size = + type_convert(1.0) / type_convert(reduce_size); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + AccDataType multiplier = + inv_reduce_size * 
inv_var_thread_buf[iM] * scale_thread_buf[iM]; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + dy_elementwise_op(dy_thread_buf(Number{}), + dy_thread_buf[Number{}]); + + AccDataType norm_x = (x_thread_buf[Number{}] - mean_thread_buf[iM]) * + inv_var_thread_buf[iM]; + + AccDataType tmpVal = norm_x * dscale_thread_buf[iM]; + + dx_thread_buf(Number{}) = + multiplier * + (type_convert(reduce_size) * dy_thread_buf[Number{}] - + dbias_thread_buf[iM] - tmpVal); + }); + }); + + threadwise_dx_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + dx_thread_buf, + dx_grid_desc_m_k, + dx_global_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp new file mode 100755 index 000000000000..b6c83af13a18 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp @@ -0,0 +1,483 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math_v2.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_batchnorm_forward_with_blockwise_welford( + const XYGridDesc_M_K x_grid_desc_m_k, + const XYGridDesc_M_K y_grid_desc_m_k, + const ScaleBiasGridDesc_M scale_grid_desc_m, + const ScaleBiasGridDesc_M bias_grid_desc_m, + const MeanVarGridDesc_M mean_var_grid_desc_m, + const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x, + const ScaleDataType* const __restrict__ p_scale, + const BiasDataType* const __restrict__ p_bias, + const YElementwiseOp y_elementwise_op, + YDataType* const __restrict__ p_y, + bool updateMovingAverage, + AccDataType averageFactor, + MeanVarDataType* const __restrict__ resultRunningMean, + MeanVarDataType* const __restrict__ resultRunningVariance, + bool saveMeanInvVariance, + MeanVarDataType* const __restrict__ resultSaveMean, + MeanVarDataType* const __restrict__ resultSaveInvVariance) +{ + GridwiseBatchrNormForwardWithBlockwiseWelford_::Run(x_grid_desc_m_k, + y_grid_desc_m_k, + scale_grid_desc_m, + bias_grid_desc_m, + mean_var_grid_desc_m, + get_reduce_count_per_thread, + num_k_block_tile_iteration, + epsilon, + p_x, + p_scale, + p_bias, + y_elementwise_op, + p_y, + updateMovingAverage, + averageFactor, + resultRunningMean, + resultRunningVariance, + saveMeanInvVariance, + resultSaveMean, + resultSaveInvVariance); +}; + +template +struct GridwiseBatchNormForwardWithBlockwiseWelford +{ + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize 
== 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k, + const XYGridDesc_M_K& y_grid_desc_m_k, + const ScaleBiasGridDesc_M& scale_grid_desc_m, + const ScaleBiasGridDesc_M& bias_grid_desc_m, + const MeanVarGridDesc_M& mean_var_grid_desc_m, + const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const XDataType* const __restrict__ p_x, + const ScaleDataType* const __restrict__ p_scale, + const BiasDataType* const __restrict__ p_bias, + const YElementwiseOp y_elementwise_op, + YDataType* const __restrict__ p_y, + bool updateMovingAverage, + AccDataType averageFactor, + MeanVarDataType* const __restrict__ resultRunningMean, + MeanVarDataType* const __restrict__ resultRunningVariance, + bool saveMeanInvVariance, + MeanVarDataType* const __restrict__ resultSaveMean, + MeanVarDataType* const __restrict__ resultSaveInvVariance) + { + using ck::math::sqrt; + + StaticBuffer + x_thread_buf; + + StaticBuffer scale_thread_buf; + + StaticBuffer bias_thread_buf; + + StaticBuffer + y_thread_buf; + + StaticBuffer mean_thread_buf; + StaticBuffer var_thread_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index(block_global_id * 
M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize), + y_elementwise_op); + + auto threadwise_scale_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + ScaleSrcVectorSize, + 1, + true>( + scale_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2, + 0, + BiasSrcVectorSize, + 1, + true>( + bias_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto scale_global_val_buf = make_dynamic_buffer( + p_scale, scale_grid_desc_m.GetElementSpaceSize()); + + const auto bias_global_val_buf = make_dynamic_buffer( + p_bias, bias_grid_desc_m.GetElementSpaceSize()); + + auto y_global_val_buf = make_dynamic_buffer( + p_y, y_grid_desc_m_k.GetElementSpaceSize()); + + // Step 1: do welford reduction to get mean and variance + + auto threadwise_welford = ThreadwiseWelford(); + threadwise_welford.max_count_ = get_reduce_count_per_thread(thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_welford.Run(x_thread_buf, mean_thread_buf, var_thread_buf); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + }); + + // Step 2: do normalization and output y + + threadwise_scale_load.Run(scale_grid_desc_m, + scale_global_val_buf, + thread_buffer_desc_m, + make_tuple(I0), + scale_thread_buf); + + threadwise_bias_load.Run(bias_grid_desc_m, + bias_global_val_buf, + thread_buffer_desc_m, + make_tuple(I0), + bias_thread_buf); + + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k; + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + AccDataType multiplier = + scale_thread_buf[Number{}] / sqrt(var_thread_buf[iM] + epsilon); + + AccDataType fused_mean_bias = + bias_thread_buf[Number{}] - mean_thread_buf[iM] * multiplier; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + // normalize + y_thread_buf(Number{}) = + x_thread_buf[Number{}] * multiplier + fused_mean_bias; + }); + }); + + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf, + y_grid_desc_m_k, + y_global_val_buf); + + 
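// The fused form above is an algebraic regrouping of the textbook normalization,
// hoisting everything that is constant along K out of the inner loop:
//
//   y = scale * (x - mean) / sqrt(var + eps) + bias
//     = x * multiplier + fused_mean_bias,
//       where multiplier      = scale / sqrt(var + eps)
//             fused_mean_bias = bias - mean * multiplier
//
// A scalar sketch of the per-element computation (illustrative only):
//
//   float bn_forward_elem(float x, float mean, float var, float eps,
//                         float scale, float bias)
//   {
//       float multiplier      = scale / sqrtf(var + eps);
//       float fused_mean_bias = bias - mean * multiplier;
//       return x * multiplier + fused_mean_bias;   // == scale*(x-mean)/sqrt(var+eps)+bias
//   }
//
// so the K loop is left with a single multiply-add per element.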
threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_bwd_step_m_k); + } + + // Step 3: update the moving average of mean and variance (optional) + + if(updateMovingAverage && thread_k_cluster_id == 0) + { + StaticBuffer + running_mean_thread_buf; + StaticBuffer + running_var_thread_buf; + + auto running_mean_global_buf = make_dynamic_buffer( + resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto running_var_global_buf = make_dynamic_buffer( + resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto threadwise_mean_var_load = + ThreadwiseTensorSliceTransfer_v2, + 0, + MeanVarSrcDstVectorSize, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize)); + + threadwise_mean_var_load.Run(mean_var_grid_desc_m, + running_mean_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + running_mean_thread_buf); + + threadwise_mean_var_load.Run(mean_var_grid_desc_m, + running_var_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + running_var_thread_buf); + + AccDataType oneMinusAverageFactor = type_convert(1.0) - averageFactor; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor + + mean_thread_buf[I] * averageFactor; + running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor + + var_thread_buf[I] * averageFactor; + }); + + auto threadwise_mean_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + MeanVarSrcDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_mean_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + running_mean_thread_buf, + mean_var_grid_desc_m, + running_mean_global_buf); + + threadwise_mean_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + running_var_thread_buf, + mean_var_grid_desc_m, + running_var_global_buf); + }; + + // Step 4: save mean and inv-variance (optional) + + if(saveMeanInvVariance && thread_k_cluster_id == 0) + { + auto result_mean_global_buf = make_dynamic_buffer( + resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize()); + + auto result_inv_var_global_buf = make_dynamic_buffer( + resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize()); + + // calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + var_thread_buf(I) = + type_convert(1.0f) / sqrt(epsilon + var_thread_buf[I]); + }); + + auto threadwise_mean_inv_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + MeanVarSrcDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + mean_var_grid_desc_m, + result_mean_global_buf); + + threadwise_mean_inv_var_store.Run(thread_buffer_desc_m, + make_tuple(I0), + var_thread_buf, + mean_var_grid_desc_m, + result_inv_var_global_buf); + }; + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp new file mode 100755 index 000000000000..13e9f7bd5ec8 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d_scale.hpp @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple, + const OutGrid1dDescTuple out_grid_1d_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const ElementwiseOperation elementwise_op, + const UnaryOperation unary_op, + const Scale scale_op) +{ + GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple, + out_grid_1d_desc_tuple, + p_in_global_tuple, + p_out_global_tuple, + elementwise_op, + unary_op, + scale_op); +} + +template +struct GridwiseElementwise_1D +{ + static constexpr index_t NumInput = InDataTypePointerTuple::Size(); + static constexpr index_t NumOutput = OutDataTypePointerTuple::Size(); + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size() && + NumInput == InGrid1dDescTuple::Size() && + NumOutput == OutGrid1dDescTuple::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static constexpr auto I0 = Number<0>{}; + + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + __device__ static void Run(const InGrid1dDescTuple in_grid_1d_desc_tuple, + const OutGrid1dDescTuple out_grid_1d_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const ElementwiseOperation elementwise_op, + const UnaryOperation unary_op, + const Scale scale_op) + { + const index_t thread_global_id = get_thread_global_1d_id(); + + auto in_thread_buf_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return StaticBuffer{}; + }, + Number{}); + + auto out_thread_buf_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_pointer_t; + + return StaticBuffer{}; + }, + Number{}); + + auto in_global_buf_tuple = generate_tuple( + [&](auto I) { + static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1); + + return make_dynamic_buffer( + p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + auto out_global_buf_tuple = generate_tuple( + [&](auto I) { + static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1); + + return make_dynamic_buffer( + p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + const auto thread_global_offset = make_multi_index(thread_global_id * MPerThread); + + const index_t blockSize = get_block_size(); + const index_t blockPerGrid = get_grid_size(); + const auto M = in_grid_1d_desc_tuple[I0].GetLength(I0); + const index_t loop_step = blockPerGrid * blockSize * MPerThread; + const auto loop_step_index = make_multi_index(loop_step); + + auto 
in_global_load_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + InScalarPerVectorSeq::At( + I), // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{in_grid_1d_desc_tuple[I], + thread_global_offset}; + }, + Number{}); + + auto out_global_store_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_pointer_t; + + return ThreadwiseTensorSliceTransfer_v1r3, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + OutScalarPerVectorSeq::At(I), + InMemoryDataOperationEnum::Set, + 1, + false>( + out_grid_1d_desc_tuple[I], thread_global_offset, PassThroughOp{}); + }, + Number{}); + + index_t num_iter = M / (loop_step); + do + { + static_for<0, NumInput, 1>{}([&](auto I) { + in_global_load_tuple(I).Run(in_grid_1d_desc_tuple[I], + in_global_buf_tuple[I], + thread_buffer_desc_m, + make_tuple(I0), + in_thread_buf_tuple(I)); + + in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_1d_desc_tuple[I], + loop_step_index); + }); + + static_for<0, MPerThread, 1>{}([&](auto iM) { + // get reference to in data + auto uop_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); }, + Number{}); + + // get reference to dst data + auto out_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> auto& { return out_thread_buf_tuple(I)(iM); }, + Number{}); + + unpack2(unary_op, uop_data_refs, uop_data_refs); + + auto sop_in_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); }, + Number{}); + + auto sop_out_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> auto& { return in_thread_buf_tuple(I)(iM); }, + Number{}); + + unpack2(scale_op, sop_out_data_refs, sop_in_data_refs); + + const auto in_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> const auto& { return in_thread_buf_tuple(I)(iM); }, + Number{}); + + unpack2(elementwise_op, out_data_refs, in_data_refs); + }); + + static_for<0, NumOutput, 1>{}([&](auto I) { + out_global_store_tuple(I).Run(thread_buffer_desc_m, + make_tuple(I0), + out_thread_buf_tuple[I], + out_grid_1d_desc_tuple[I], + out_global_buf_tuple(I)); + + out_global_store_tuple(I).MoveDstSliceWindow(out_grid_1d_desc_tuple[I], + loop_step_index); + }); + } while(--num_iter); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp new file mode 100755 index 000000000000..1326c5d62d5b --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp @@ -0,0 +1,400 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
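// The kernels declared below come in three flavors: a plain gridwise elementwise
// kernel, "dual" variants that pack two independent problems into one launch
// (blocks [0, a_grid_size) run problem A, the rest run problem B with
// block_id - a_grid_size), and batched variants that additionally offset the I/O
// pointers per batch. A rough scalar sketch of the batched index math; the names
// here are illustrative, not part of this header:
//
//   long long batch_offset(int block_id, int num_blocks_per_batch, long long batch_stride)
//   {
//       int g_idx = block_id / num_blocks_per_batch;   // batch handled by this block
//       return batch_stride * (long long)g_idx;        // added to the base pointer
//   }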
+ +#pragma once + +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor/static_tensor.hpp" +#include "ck/utility/common_header.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, + const OutGridDescTuple out_grid_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const Block2TileMap block_2_tile_map, + const ElementwiseOperation elementwise_op) +{ + GridwiseElementwiseFunctor::Run(in_grid_desc_tuple, + out_grid_desc_tuple, + p_in_global_tuple, + p_out_global_tuple, + block_2_tile_map, + elementwise_op); +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, + const InBGridDescTuple in_grid_desc_tuple_b, + const OutAGridDescTuple out_grid_desc_tuple_a, + const OutBGridDescTuple out_grid_desc_tuple_b, + const InADataTypePointerTuple p_in_global_tuple_a, + const InBDataTypePointerTuple p_in_global_tuple_b, + const OutADataTypePointerTuple p_out_global_tuple_a, + const OutBDataTypePointerTuple p_out_global_tuple_b, + const Block2TileMapA block_2_tile_map_a, + const Block2TileMapB block_2_tile_map_b, + const ElementwiseOperation elementwise_op, + const index_t a_grid_size) +{ + if(get_block_1d_id() < a_grid_size) + { + GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a, + out_grid_desc_tuple_a, + p_in_global_tuple_a, + p_out_global_tuple_a, + block_2_tile_map_a, + elementwise_op, + get_block_1d_id()); + } + else + { + GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b, + out_grid_desc_tuple_b, + p_in_global_tuple_b, + p_out_global_tuple_b, + block_2_tile_map_b, + elementwise_op, + get_block_1d_id() - a_grid_size); + } +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_elementwise_batched_dual( + const InAGridDescTuple in_grid_desc_tuple_a, + const InBGridDescTuple in_grid_desc_tuple_b, + const OutAGridDescTuple out_grid_desc_tuple_a, + const OutBGridDescTuple out_grid_desc_tuple_b, + const InADataTypePointerTuple p_in_global_tuple_a, + const InBDataTypePointerTuple p_in_global_tuple_b, + const OutADataTypePointerTuple p_out_global_tuple_a, + const OutBDataTypePointerTuple p_out_global_tuple_b, + const Block2TileMapA block_2_tile_map_a, + const Block2TileMapB block_2_tile_map_b, + const ElementwiseOperation elementwise_op, + const index_t a_grid_size, + const index_t batch_count_a, + const index_t batch_count_b, + const std::array input_batch_strides_a, + const std::array input_batch_strides_b, + const std::array output_batch_strides_a, + const std::array output_batch_strides_b) +{ + static_assert(InAGridDescTuple::Size() == NumInputsA && + InADataTypePointerTuple::Size() == NumInputsA); + static_assert(OutAGridDescTuple::Size() == NumOutputsA && + 
OutADataTypePointerTuple::Size() == NumOutputsA); + static_assert(InBGridDescTuple::Size() == NumInputsB && + InBDataTypePointerTuple::Size() == NumInputsB); + static_assert(OutBGridDescTuple::Size() == NumOutputsB && + OutBDataTypePointerTuple::Size() == NumOutputsB); + + const index_t block_id = __builtin_amdgcn_readfirstlane(get_block_1d_id()); + + if(block_id < a_grid_size) + { + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(a_grid_size / batch_count_a); + const index_t g_idx = __builtin_amdgcn_readfirstlane(block_id / num_blocks_per_batch); + + InADataTypePointerTuple p_in_global_with_offset_tuple; + OutADataTypePointerTuple p_out_global_with_offset_tuple; + + static_for<0, InADataTypePointerTuple::Size(), 1>{}([&](auto i) { + p_in_global_with_offset_tuple(i) = + p_in_global_tuple_a.At(i) + + type_convert(input_batch_strides_a[i]) * g_idx; + }); + + static_for<0, OutADataTypePointerTuple::Size(), 1>{}([&](auto i) { + p_out_global_with_offset_tuple(i) = + p_out_global_tuple_a.At(i) + + type_convert(output_batch_strides_a[i]) * g_idx; + }); + + GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a, + out_grid_desc_tuple_a, + p_in_global_with_offset_tuple, + p_out_global_with_offset_tuple, + block_2_tile_map_a, + elementwise_op, + block_id); + } + else + { + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane((get_grid_size() - a_grid_size) / batch_count_b); + const index_t g_idx = + __builtin_amdgcn_readfirstlane((block_id - a_grid_size) / num_blocks_per_batch); + + InBDataTypePointerTuple p_in_global_with_offset_tuple; + OutBDataTypePointerTuple p_out_global_with_offset_tuple; + + static_for<0, InBDataTypePointerTuple::Size(), 1>{}([&](auto i) { + p_in_global_with_offset_tuple(i) = + p_in_global_tuple_b.At(i) + + type_convert(input_batch_strides_b[i]) * g_idx; + }); + + static_for<0, OutBDataTypePointerTuple::Size(), 1>{}([&](auto i) { + p_out_global_with_offset_tuple(i) = + p_out_global_tuple_b.At(i) + + type_convert(output_batch_strides_b[i]) * g_idx; + }); + + GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b, + out_grid_desc_tuple_b, + p_in_global_with_offset_tuple, + p_out_global_with_offset_tuple, + block_2_tile_map_b, + elementwise_op, + block_id - a_grid_size); + } +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple, + const OutGridDescTuple out_grid_desc_tuple, + const InDataTypePointerTuple p_in_global_tuple, + const OutDataTypePointerTuple p_out_global_tuple, + const Block2TileMap block_2_tile_map, + const ElementwiseOperation elementwise_op, + const index_t batch_count, + const std::array input_batch_strides, + const std::array output_batch_strides) +{ + static_assert(InGridDescTuple::Size() == NumInputs && + InDataTypePointerTuple::Size() == NumInputs); + static_assert(OutGridDescTuple::Size() == NumOutputs && + OutDataTypePointerTuple::Size() == NumOutputs); + + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + InDataTypePointerTuple p_in_global_with_offset_tuple; + OutDataTypePointerTuple p_out_global_with_offset_tuple; + + static_for<0, InDataTypePointerTuple::Size(), 1>{}([&](auto i) { + p_in_global_with_offset_tuple(i) = + p_in_global_tuple.At(i) + type_convert(input_batch_strides[i]) * g_idx; + }); + + static_for<0, 
OutDataTypePointerTuple::Size(), 1>{}([&](auto i) { + p_out_global_with_offset_tuple(i) = + p_out_global_tuple.At(i) + type_convert(output_batch_strides[i]) * g_idx; + }); + + GridwiseElementwiseFunctor::Run(in_grid_desc_tuple, + out_grid_desc_tuple, + p_in_global_with_offset_tuple, + p_out_global_with_offset_tuple, + block_2_tile_map, + elementwise_op); +} + +template +struct GridwiseElementwise +{ + static constexpr index_t NumInput = InDataTypePointerTuple::Size(); + static constexpr index_t NumOutput = OutDataTypePointerTuple::Size(); + + static_assert(NumInput == InScalarPerVectorSeq::Size() && + NumOutput == OutScalarPerVectorSeq::Size() && + NumInput == InGridDescTuple::Size() && NumOutput == OutGridDescTuple::Size(), + "Tuple size is inconsistent with the number of in/out!"); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static_assert((SrcVectorDim == I0 || SrcVectorDim == I1) && + (DstVectorDim == I0 || DstVectorDim == I1), + "Vector dim must be equal to 0 or 1."); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + __device__ static void Run(const InGridDescTuple& in_grid_desc_tuple, + const OutGridDescTuple& out_grid_desc_tuple, + const InDataTypePointerTuple& p_in_global_tuple, + const OutDataTypePointerTuple& p_out_global_tuple, + const Block2TileMap& block_2_tile_map, + const ElementwiseOperation& elementwise_op, + const index_t block_id = get_block_1d_id()) + { + + constexpr auto src_datas = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return DataType{}; + }, + Number{}); + + constexpr auto dst_datas = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_pointer_t; + + return DataType{}; + }, + Number{}); + + const auto in_global_buf_tuple = generate_tuple( + [&](auto I) { + return make_dynamic_buffer( + p_in_global_tuple[I], in_grid_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + auto out_global_buf_tuple = generate_tuple( + [&](auto I) { + return make_dynamic_buffer( + p_out_global_tuple[I], out_grid_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + const auto block_work_idx = + block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id)); + + const index_t m0_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock); + const index_t m1_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock); + const auto input_thread_grid_offset = generate_tuple( + [&](auto) { + return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid); + }, + Number{}); + const auto output_thread_grid_offset = generate_tuple( + [&](auto) { + return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid); + }, + Number{}); + + using ThisThreadBlock = ThisThreadBlock; + // If src and dst have same vector dim, then: + // M0 dim - for src and dst vector load/store + // else: + // M0 dim - for dst vector load + // M1 dim - for src vector store + using SrcDimAccessOrder = + std::conditional_t, Sequence<1, 0>>; + using DstDimAccessOrder = + std::conditional_t, Sequence<1, 0>>; + + using ThreadClusterLengths = + Sequence{}, Number{}>; + + auto global_to_global_transfer = ThreadGroupTensorSliceTransfer_v4r2< + ThisThreadBlock, + ElementwiseOperation, + uniform_sequence_gen_t(InMemoryDataOperationEnum::Set)>, + Sequence, + ThreadClusterLengths, + ThreadClusterArrangeOrder, + decltype(src_datas), + 
decltype(dst_datas), + InGridDescTuple, + OutGridDescTuple, + SrcDimAccessOrder, + DstDimAccessOrder, + SrcVectorDim, + DstVectorDim, + InScalarPerVectorSeq, + OutScalarPerVectorSeq, + uniform_sequence_gen_t, + uniform_sequence_gen_t, + uniform_sequence_gen_t, + uniform_sequence_gen_t>{in_grid_desc_tuple, + input_thread_grid_offset, + out_grid_desc_tuple, + output_thread_grid_offset, + elementwise_op}; + global_to_global_transfer.Run( + in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp new file mode 100755 index 000000000000..072275c089ce --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp @@ -0,0 +1,500 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// X = Elementwise(input1, input2, input3, ...) +// Y = Normalization(X, beta, gamma) +template +struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk +{ + static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr index_t NumInput = InDataTypePointerTuple::Size(); + + static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize; + + static constexpr auto XThreadBufferNumber = Number{}; + static constexpr auto GammaThreadBufferNumber = Number{}; + static constexpr auto BetaThreadBufferNumber = Number{}; + 
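// Reference semantics of this fused kernel, per row m (reduction over K): the inputs
// are first combined with x_elementwise_op into x, mean/variance of x are computed
// along K with Welford, and the row is then normalized with gamma/beta (read through
// their own M_K descriptors). A scalar sketch, taking gamma/beta as per-K vectors for
// simplicity (illustrative only, not the CK implementation; assumes K > 0):
//
//   void layernorm_row(const float* x, const float* gamma, const float* beta,
//                      float eps, int K, float* y)
//   {
//       float mean = 0.f, m2 = 0.f;                // Welford over the row
//       for(int k = 0; k < K; ++k)
//       {
//           float delta = x[k] - mean;
//           mean += delta / (k + 1);
//           m2   += delta * (x[k] - mean);
//       }
//       float inv_std = 1.f / sqrtf(m2 / K + eps); // population (biased) variance
//       for(int k = 0; k < K; ++k)
//           y[k] = gamma[k] * (x[k] - mean) * inv_std + beta[k];
//   }
//
// When SweepOnce is set, K is covered by a single tile iteration and x is kept in
// registers instead of being staged through LDS; otherwise x is written to LDS on the
// Welford pass and re-read on the normalization pass.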
static constexpr auto YThreadBufferNumber = Number{}; + + __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k, + int thread_k_cluster_id) + { + int kPerBlock = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; + int kPerThread = + kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize); + int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize; + + if(kPerBlockTail > 0) + { + static_for<0, XThreadBufferNumber, 1>{}([&](auto i) { + int thread_max_len = + (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i; + int delta = thread_max_len - kPerBlockTail; + delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize); + kPerThread += XSrcVectorSize - delta; + }); + } + + return kPerThread; + } + + __device__ static void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple, + const GridDesc_M_K& x_grid_desc_m_k, + const GridDesc_M_K& gamma_grid_desc_m_k, + const GridDesc_M_K& beta_grid_desc_m_k, + const GridDesc_M_K& y_grid_desc_m_k, + index_t num_k_block_tile_iteration, + AccDataType epsilon, + const InDataTypePointerTuple p_in_global_tuple, + XDataType* const __restrict__ p_x_lds_, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + const XElementwiseOperation x_elementwise_op, + const YElementwiseOperation y_elementwise_op) + { + if constexpr(SweepOnce) + { + num_k_block_tile_iteration = 1; + } + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t grid_size = get_grid_size(); + + auto in_global_buf_tuple = generate_tuple( + [&](auto I) { + static_assert(in_grid_2d_desc_tuple[I].GetNumOfDimension() == + 2); // matrix dimension + + return make_dynamic_buffer( + p_in_global_tuple[I], in_grid_2d_desc_tuple[I].GetElementSpaceSize()); + }, + Number{}); + + auto y_global_val_buf = make_dynamic_buffer( + p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + + auto x_lds_val_buf = make_dynamic_buffer( + p_x_lds_, x_grid_desc_m_k.GetElementSpaceSize() / grid_size); + + auto in_thread_buf_tuple = generate_tuple( + [&](auto) { + return generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + }, + Number{}); + + auto x_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto gamma_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto beta_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto y_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + StaticBuffer mean_thread_buf; + StaticBuffer var_thread_buf; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + using ThreadBufferLengths_M_K = Sequence; + + constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto in_global_load_tuple = generate_tuple( + [&](auto I) { + using DataTypePointer = remove_cvref_t; + using DataType = remove_cv_t>; + + return ThreadwiseTensorSliceTransfer_v2{ + in_grid_2d_desc_tuple[I], + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * XSrcVectorSize)}; + }, + Number{}); + + auto 
threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * XSrcVectorSize)); + + auto threadwise_gamma_load = + ThreadwiseTensorSliceTransfer_v2( + gamma_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * GammaSrcVectorSize)); + + auto threadwise_beta_load = + ThreadwiseTensorSliceTransfer_v2( + beta_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * BetaSrcVectorSize)); + + using PassThrough = tensor_operation::element_wise::PassThrough; + PassThrough pass_through_op; + auto threadwise_x_store = + ThreadwiseTensorSliceTransfer_v1r3( + x_grid_desc_m_k, + make_multi_index(thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * XSrcVectorSize), + pass_through_op); + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * YDstVectorSize), + y_elementwise_op); + + // Copy x from Cache + // one pass: fwd, second pass: bwd + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); + constexpr auto thread_copy_bwd_step_m_k = + make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); + + const auto gamma_global_val_buf = make_dynamic_buffer( + p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize()); + + const auto beta_global_val_buf = make_dynamic_buffer( + p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); + + auto threadwise_welford = ThreadwiseWelford(); + threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, NumInput, 1>{}([&](auto I) { // input load loop + in_global_load_tuple(I).Run(in_grid_2d_desc_tuple[I], + in_global_buf_tuple[I], + thread_buffer_desc_m_k, + make_tuple(I0, I0), + in_thread_buf_tuple(iK0)(I)); + + in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_2d_desc_tuple[I], + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { // input add loop + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // get reference to in data + const auto in_data_refs = generate_tie( + // return type should be lvalue + [&](auto I) -> const auto& { + return in_thread_buf_tuple(iK0)(I)(Number{}); + }, + Number{}); + + // get reference to dst data + auto out_data_refs = generate_tie( + // return type should be lvalue + [&](auto) -> auto& { return x_thread_buf(iK0)(Number{}); }, + I1); + + unpack2(x_elementwise_op, out_data_refs, in_data_refs); + }); + }); + threadwise_welford.Run(x_thread_buf[iK0], mean_thread_buf, var_thread_buf); + + if constexpr(!SweepOnce) + { + threadwise_x_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(iK0), + x_grid_desc_m_k, + x_lds_val_buf); + threadwise_x_store.MoveDstSliceWindow(x_grid_desc_m_k, + thread_copy_fwd_step_m_k); + } + }); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + 
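+ // (assumed intent) barrier between successive BlockwiseWelford reductions so the shared
+ // workspace of the previous M-slice iteration is not overwritten before it is fully read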
block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + }); + + auto thread_copy_tail_m_k = + (num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k; + + if constexpr(!SweepOnce) + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + if constexpr(!SweepOnce) + { + static_for<0, XThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_lds_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + } + + static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) { + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf(i)); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon); + static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + divisor; + + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) { + threadwise_beta_load.Run(beta_grid_desc_m_k, + beta_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + beta_thread_buf(i)); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, YThreadBufferNumber, 1>{}([&](auto i) { + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf(i), + y_grid_desc_m_k, + y_global_val_buf); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + + if constexpr(!SweepOnce) + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp new file mode 100755 index 000000000000..21dac6f9e9eb --- 
/dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -0,0 +1,1053 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_dequant.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_fpAintB_gemm_wmma(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + const ScaleDataType* __restrict__ p_scale_grid, + CDataType* __restrict__ p_c_grid, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const ScaleGridDesc scale_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_scale_grid, + p_c_grid, + p_shared, + a_grid_desc, + b_grid_desc, + scale_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_scale_grid; + ignore = p_c_grid; + ignore = a_grid_desc; + ignore = b_grid_desc; + ignore = scale_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx11__)) +} + +// Assume B is Col-Major +template +struct GridwiseFpAintBGemm_Wmma +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // FIX ME: To be deprecated + static constexpr auto K1 = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 
32 : 16; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = + remove_cvref_t())>; + + // Describe how data store to (LDS/VGPR) buffer from Global memory + __host__ __device__ static constexpr auto MakeABlockDescriptor() + { + constexpr auto a_block_desc = [&]() { + if constexpr(AEnableLds) + { + // K0->M->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); + } + }(); + + return a_block_desc; + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor() + { + constexpr auto b_block_desc = [&]() { + if constexpr(BEnableLds) + { + // K0->N->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + } + else + { + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / 2 / K1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); + } + }(); + + return b_block_desc; + } + + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() + { + constexpr auto a_block_copy_step = [&]() { + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return a_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep() + { + constexpr auto b_block_copy_step = [&]() { + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b_block_copy_step; + } + + // Describe how data read from (LDS/VGPR) buffer + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + { + + constexpr auto a_wave_desc = [&]() { + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else + constexpr auto A_KRow = I1; +#endif + return transform_tensor_descriptor( + ABlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + 
make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3); + constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6); + + // Err: merge transform cause non-constexpr issue + + // return transform_tensor_descriptor( + // ABlockDesc_{}, + // make_tuple(make_merge_transform(make_tuple(Number{}, I1)), + // make_pass_through_transform(Number{}), + // make_pass_through_transform(I1), + // make_pass_through_transform(I1), + // make_pass_through_transform(Number{})), + // make_tuple(Sequence<0, 3>{}, + // Sequence<1>{}, + // Sequence<2>{}, + // Sequence<4>{}, + // Sequence<5>{}), + // make_tuple( + // Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, + // Sequence<4>{})); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return a_wave_desc; + } + + template + __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&) + { + constexpr auto b_wave_desc = [&]() { + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else + constexpr auto B_KRow = I1; +#endif + return transform_tensor_descriptor( + BBlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = BBlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3); + constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return b_wave_desc; + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! 
K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerWmma)) == 0, + "Invalid tuning param!"); + + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetBProblemsizeNK = [&]() { + if constexpr(BEnableLds) + { + return make_tuple(b_grid_desc.GetLength(I1), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * + b_grid_desc.GetLength(I5), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * + b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto N = GetBProblemsizeNK()[I0]; + const auto K = GetAProblemsizeMK()[I1]; + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K == GetBProblemsizeNK()[I1])) + { + printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", + GetAProblemsizeMK()[I0], + GetAProblemsizeMK()[I1], + GetBProblemsizeNK()[I0], + GetBProblemsizeNK()[I1], + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + printf("GridwiseOp err: ProblemSize check"); + return false; + } + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + printf("GridwiseOp err: ProblemSize division"); + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + printf("GridwiseOp err: Pipeline not support this k_loop"); + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB)) + { + return false; + } + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using 
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DefaultBlock2CTileMap = + remove_cvref_t; + + struct SharedMemTrait + { + // LDS allocation for A and Dequantized B: be careful of DataType + // scale would not put into LDS. + using LDS_ADataType = ADataType; + using LDS_BDataType = ADataType; + using LDS_CDataType = CShuffleDataType; + static constexpr auto max_lds_align = K1; + + static constexpr auto a_block_space_size_aligned = + AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + static constexpr auto b_block_space_size_aligned = + BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + + static constexpr auto a_block_space_offset = 0; + // B would be dequantize to ADataType before enter LDS + // b_lds_offset = LDS size allocated for a in byte / LDS_BDataType + static constexpr auto b_block_space_offset = + (a_block_space_offset + a_block_space_size_aligned) * sizeof(LDS_ADataType) / + sizeof(LDS_BDataType); + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_space_size = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + .GetElementSpaceSize(); + + static constexpr auto c_shuffle_block_space_offset = 0; + + static constexpr auto lds_size = + math::max(c_shuffle_block_space_size * sizeof(LDS_CDataType), + a_block_space_size_aligned * sizeof(LDS_ADataType) + + b_block_space_size_aligned * sizeof(LDS_BDataType)); + }; + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + const ScaleDataType* __restrict__ p_scale_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const ScaleGridDesc& scale_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + // clang-format off +/*******************************************************************************/ +// Memory buffer zone. 
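+ // Wrap the raw global pointers (A, B, the dequantization scale, and C) in dynamic buffers
+ // so the blockwise/threadwise transfers below can address them through their grid descriptors.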
+ const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc.GetElementSpaceSize()); + const auto scale_grid_buf = make_dynamic_buffer( + p_scale_grid, scale_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + +/*******************************************************************************/ +// BlockIdx.x -> [BlockId.m, BlockId.n] + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { return; } + + // Store BlockId into SGPR + const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + +/*******************************************************************************/ +// BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destinaion of BlockWise_Copy + const auto K = [&](){ + if constexpr(AEnableLds){ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); + } + else{ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) + * a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6); + } + }(); + + constexpr auto a_block_desc = MakeABlockDescriptor(); + constexpr auto b_block_desc = MakeBBlockDescriptor(); + + auto a_block_trait = [&](){ + // A matrix blockwise copy + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + SharedMemTrait::a_block_space_size_aligned); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, +/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, +/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ ADataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(a_grid_desc), +/* typename DstDesc, */ decltype(a_block_desc), +/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + a_grid_desc, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto a_block_buf = make_static_buffer( + a_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto 
a_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + ABlockTransferSrcScalarPerVector, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc, + make_multi_index(0, + m_block_data_idx_on_grid/(MWaves * MPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + }; + + auto b_block_trait = [&](){ + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b_block_space_offset, + SharedMemTrait::b_block_space_size_aligned); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1_dequant, +/* typename BlockScaleSliceLengths, */ Sequence, +/* typename ThreadClusterLengths, */ BBlockTransferThreadClusterLengths_K0_N_K1, +/* typename ThreadClusterArrangeOrder, */ BBlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ BDataType, +/* typename ScaleData, */ ScaleDataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(b_grid_desc), +/* typename ScaleDesc, */ decltype(scale_grid_desc), +/* typename DstDesc, */ decltype(b_block_desc), +/* typename SrcDimAccessOrder, */ BBlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ BBlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ BBlockTransferSrcScalarPerVector, +/* index_t ScaleScalarPerVector, */ 1, +/* index_t DstScalarPerVector, */ BBlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t ScaleScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ BThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + b_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + scale_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + ck::tensor_operation::element_wise::PassThrough{}, + b_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto b_block_buf = make_static_buffer( + b_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc, + make_multi_index(0, + n_block_data_idx_on_grid/(NWaves * NPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + }; + + auto a_block_buf = a_block_trait()[I0]; + auto a_blockwise_copy = a_block_trait()[I1]; + + auto b_block_buf = b_block_trait()[I0]; + auto b_blockwise_copy = b_block_trait()[I1]; +/*******************************************************************************/ + // GEMM + constexpr auto 
KPack = math::integer_least_multiple(K1, WmmaK); + + auto blockwise_gemm = + BlockwiseGemmWMMA{}; + + // Prepare Register for C matrix + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + +/*******************************************************************************/ + // Shift Per SUB_K + constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); + constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep(); + + // gridwise GEMM pipeline + const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock); + GridwiseGemmPipe::template Run(a_grid_desc, + a_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc, + b_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + scale_grid_desc, + scale_grid_buf, + blockwise_gemm, + c_thread_buf, + KBlockMainLoop); +/*******************************************************************************/ + // write out to C, implement shuffle + { + // C mapping in single thread. + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // C mapping in single block + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1); + constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2); + constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4); + constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5); + constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::c_shuffle_block_space_offset, + SharedMemTrait::c_shuffle_block_space_size); + + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = 
c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 1, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + 
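+ // For each space-filling-curve access: stage one CShuffle tile from VGPR into LDS,
+ // synchronize, stream it out to global memory, then advance the destination window.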
// each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + // clang-format on + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp new file mode 100755 index 000000000000..f406bfb95aa8 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp @@ -0,0 +1,1014 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_bias_add_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC0* __restrict__ p_bias_grid, + const FloatC1* __restrict__ p_d0_grid, + ReducePtrsGlobal p_reduces_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const C1ElementwiseOperation c1_element_op, + const ReduceInElementwiseOperations reduce_in_element_ops, + const ReduceAccElementwiseOperations reduce_out_element_ops, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c0_grid_desc_mblock_mperblock_nblock_nperblock, + const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c1_grid_desc_mblock_mperblock_nblock_nperblock, + const 
ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_bias_grid, + p_d0_grid, + p_reduces_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + c1_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c1_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_bias_grid; + ignore = p_d0_grid; + ignore = p_reduces_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c1_element_op; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = c0_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = c1_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = reduce_grid_desc_mblock_mperblock; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = 
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + // static_assert(is_known_at_compile_time>::value && + // is_known_at_compile_time>::value, + // "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N_& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr auto + MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m) + { + const auto M = d_grid_desc_m.GetLength(I0); + const auto MBlock = M / MPerBlock; + + const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor( + d_grid_desc_m, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{}))), + make_tuple(Sequence<0>{}), + 
make_tuple(Sequence<0, 1>{})); + + return reduce_grid_desc_mblock_mperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + using C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + using C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + using ReduceGridDescriptor_MBlock_MPerBlock = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC0* __restrict__ p_bias_grid, + const FloatC1* __restrict__ p_d0_grid, + ReducePtrsGlobal p_reduces_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const C1ElementwiseOperation& c1_element_op, + const ReduceInElementwiseOperations& reduce_in_element_ops, + const ReduceAccElementwiseOperations& reduce_out_element_ops, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c0_grid_desc_mblock_mperblock_nblock_nperblock, + const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c1_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c0_grid_buf = make_dynamic_buffer( + p_bias_grid, c0_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c1_grid_buf = make_dynamic_buffer( + p_d0_grid, c1_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + 
ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? 
true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( + lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C + reduction + write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
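+ // Decompose the accumulator layout M0_N0_M1_N1_M2_M3_M4_N2 below so its lengths can be used
+ // to locate each thread's tile inside the CShuffle staging buffer in LDS.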
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + 
SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + // TODO: this should be implemented as a blockwise reduction + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) * + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + BlockSize, + "wrong!"); + + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) == + 0 && + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1); + + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; + + // VGPR c_reduce_thread_desc_mperblock_nperblock + constexpr auto c_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + // VGPR reduce_thread_desc_mperblock + constexpr auto reduce_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + // VGPR reduce_thread_desc_mblock_mperblock + constexpr auto reduce_thread_desc_mblock_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); + + auto c_reduce_thread_buf = make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + // reduce: threadwise copy from LDS to VGPR + constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( + CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = + c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; + + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatCShuffle, + FloatReduceAcc, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; + + auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple( + [&](auto I) { + auto p_reduce_grid = p_reduces_grid[I]; + auto reduce_acc_element_op = 
reduce_out_element_ops[I]; + + return ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + remove_pointer_t, + decltype(reduce_thread_desc_mblock_mperblock), + decltype(reduce_grid_desc_mblock_mperblock), + decltype(reduce_acc_element_op), + Sequence<1, mreduce_per_thread>, + Sequence<0, 1>, + 1, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + ReduceGlobalMemoryDataOperation::At(I), + 1, + false>{reduce_grid_desc_mblock_mperblock, + make_multi_index(block_work_idx[I0], // mblock + c_reduce_thread_data_idx_begin[I0]), // mperblock + reduce_acc_element_op}; + }, + Number{}); + + // c0 and c1 + constexpr auto c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + constexpr auto c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock; + + auto c01_thread_buf = make_static_buffer( + c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatC0, + FloatReduceAcc, + decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1])); + + auto c1_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatC1, + FloatReduceAcc, + decltype(c1_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>( + c1_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1])); + + constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + auto c_reduce_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + FloatC, + decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + tensor_operation::element_wise::PassThrough, + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + 3, // DstVectorDim + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]), + tensor_operation::element_wise::PassThrough{}}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to write to LDS + 
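+            // block_sync_lds() is the LDS barrier between the two phases of this access:
+            // every thread's VGPR->LDS store of the shuffled C tile above must be visible
+            // before the LDS->VGPR reads below pick up data written by other threads.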
block_sync_lds(); + { + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf); + + c0_thread_copy_global_to_vgpr.Run( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_buf, + c0_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c01_thread_buf); + + // c = activation(c + bias) + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + FloatReduceAcc out; + c_element_op(out, c_reduce_thread_buf(i) + c01_thread_buf(i)); + c_reduce_thread_buf(i) = out; + }); + + c1_thread_copy_global_to_vgpr.Run( + c1_grid_desc_mblock_mperblock_nblock_nperblock, + c1_grid_buf, + c1_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c01_thread_buf); + + // c = c + c1_functior(c1) + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + c1_element_op(c01_thread_buf(i), c01_thread_buf(i)); + c_reduce_thread_buf(i) += c01_thread_buf(i); + }); + + c_reduce_thread_copy_vgpr_to_global.Run( + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c_reduce_thread_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) { + auto& p_reduce_grid = p_reduces_grid[In]; + + auto reduce_grid_buf = make_dynamic_buffer( + p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize()); + + auto reduce_thread_buf = + make_static_buffer( + reduce_thread_desc_mperblock.GetElementSpaceSize()); + + auto& reduce_in_element_op = reduce_in_element_ops[In]; + + auto& reduce_thread_copy_vgpr_to_global = + reduce_tuple_thread_copy_vgpr_to_global(In); + + using ReduceOperation = remove_cvref_t; + using ThreadwiseReduce = + ThreadwiseReduction; + + // Global write Gemm shuffle + reduction + const auto reduce_identityVal = + ReduceOperation::template GetIdentityValue(); + + static_for<0, mreduce_per_thread, 1>{}( + [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; }); + + // reduce in VGPR + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto offset = + Number{}; + + reduce_in_element_op(c_reduce_thread_buf(offset), + c_reduce_thread_buf(offset)); + }); + }); + + ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf); + + // copy from VGPR to Global + reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + reduce_thread_buf, + reduce_grid_desc_mblock_mperblock, + reduce_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + reduce_grid_desc_mblock_mperblock, + make_tuple(c_global_step[I0], c_global_step[I1])); + } + }); + } + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + + // move on C0 + c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + + // move on C1 + c1_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + c1_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } // 
Reduction + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp new file mode 100755 index 000000000000..5f6f2768ebc8 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp @@ -0,0 +1,690 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +struct GridwiseGemmDlMultipleD_km_kn_mn +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // TODO: change this. 
I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = + math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = + math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n) + { + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB && + b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB && + c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB)) + { + return false; + } + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && + K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2)) && + (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0); + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) + { + const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1; + + return has_main_k_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0) + { + const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0; + + return has_double_tail_k_block_loop; + } + + __host__ __device__ static constexpr auto + MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1) + { + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + + const auto M1 = Number{}; + const auto M0 = M / M1; + + const auto a_grid_desc_k0_m0_m1_k1 = + transform_tensor_descriptor(a_grid_desc_k0_m_k1, + make_tuple(make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(M0, M1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return a_grid_desc_k0_m0_m1_k1; + } + + __host__ __device__ static constexpr auto + MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1) + { + const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + + const auto N1 = Number{}; + const auto N0 
= N / N1; + + const auto b_grid_desc_k0_n0_n1_k1 = + transform_tensor_descriptor(b_grid_desc_k0_n_k1, + make_tuple(make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(N0, N1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return b_grid_desc_k0_n0_n1_k1; + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N_& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + constexpr auto M11 = + Number{}; + constexpr auto N11 = + Number{}; + + constexpr auto M10 = M1 / M11; + constexpr auto N10 = N1 / N11; + + const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), + make_unmerge_transform(make_tuple(N0, N10, N11))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_m0_m10_m11_n0_n10_n11; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { return MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(ds_grid_desc_m_n[i]); }, + Number{}); + } + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N00_M01_N01( + c_grid_desc_m_n); + } + + using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared_block, + const AElementwiseOperation&, + const BElementwiseOperation&, + const CDEElementwiseOperation& cde_element_op, + const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1, + const DsGridDesc_M0_M10_M11_N0_N10_N11& ds_grid_desc_m0_m10_m11_n0_n10_n11, + const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap& block_2_ctile_map, + integral_constant, + integral_constant) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto c_m0_n0_block_cluster_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force index data into SGPR + const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); + const index_t in0 = 
__builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); + + if(!block_2_ctile_map.ValidCTileIndex( + make_tuple(im0, in0), + make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0), + c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3)))) + { + return; + } + + // TODO: change this. I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // A matrix in LDS memory, for blockwise GEMM + constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, for blockwise GEMM + constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() == + a_k0_m_k1_block_desc.GetElementSpaceSize() && + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() == + b_k0_n_k1_block_desc.GetElementSpaceSize() && + "wrong!"); + + // A matrix blockwise copy + auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< + BlockSize, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + remove_reference_t, + decltype(a_block_desc_k0_m0_m1_k1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths + ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths + ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder + false, + true>(a_grid_desc_k0_m0_m1_k1, + make_multi_index(0, im0, 0, 0), + a_block_desc_k0_m0_m1_k1, + make_multi_index(0, 0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< + BlockSize, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + remove_reference_t, + decltype(b_block_desc_k0_n0_n1_k1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths + BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths + BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder + false, + true>(b_grid_desc_k0_n0_n1_k1, + make_multi_index(0, in0, 0, 0), + b_block_desc_k0_n0_n1_k1, + make_multi_index(0, 0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[KPerBlocl, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + const auto blockwise_gemm = + 
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + M1PerThreadM111, + N1PerThreadN111, + KPerThread, + M11N11ThreadClusterM110Xs, + M11N11ThreadClusterN110Xs, + M1PerThreadM111, + N1PerThreadN111>{}; + + constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = + decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); + + constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed( + sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = static_cast(p_shared_block); + FloatAB* p_b_block_double = + static_cast(p_shared_block) + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize()); + + // Initialize C + c_thread_buf.Clear(); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); + + auto a_block_odd_buf = make_dynamic_buffer( + p_a_block_double + a_block_aligned_space_size, + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0); + + index_t k_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, + b_block_slice_copy_step); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11, + a_block_even_buf, + b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, + 
b_block_slice_copy_step); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); + + k_block_data_begin += 2 * K0PerBlock; + } while(k_block_data_begin < K0 - 2 * K0PerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); + + block_sync_lds(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = + blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m0_m10_m11_n0_n10_n11[i].GetElementSpaceSize()); + }, + Number{}); + + auto ds_thread_buf = generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return StaticBuffer{}; + }, + Number{}); + + auto ds_threadwise_copy = generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return ThreadwiseTensorSliceTransfer_v2< + DDataType, + DDataType, + decltype(ds_grid_desc_m0_m10_m11_n0_n10_n11[i]), + decltype(c_thread_desc_m0_m10_m11_n0_n10_n11), + Sequence{}>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + 1, + false>(ds_grid_desc_m0_m10_m11_n0_n10_n11[i], + make_multi_index(im0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], + in0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])); + }, + Number{}); + + static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I0], 1>{}([&](auto m10) { + static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I1], 1>{}([&](auto m11) { + static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I2], 1>{}([&](auto n10) { 
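+                    // Innermost epilogue loop: for each (m10, m11, n10) slice owned by this
+                    // thread, load the D tensors into registers, apply cde_element_op
+                    // element-wise along n11 against the accumulator, then advance the D
+                    // source windows to the next slice.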
+ // load d matrix data + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_threadwise_copy(i).Run(ds_grid_desc_m0_m10_m11_n0_n10_n11[i], + ds_grid_buf[i], + c_thread_desc_m0_m10_m11_n0_n10_n11, + make_tuple(I0, I0, I0, I0, I0, I0), + ds_thread_buf(i)); + }); + // cal element op + static_for<0, c_m10_m11_n10_n11_thread_tensor_lengths[I3], 1>{}( + [&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + return ds_thread_buf[iSrc][i]; + }, + Number{}); + + // get reference to dst data + constexpr index_t c_offset = + c_thread_desc_m0_m10_m11_n0_n10_n11.CalculateOffset( + make_tuple(0, m10, m11, 0, n10, i)); + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto) -> auto& { return c_thread_buf(Number{}); }, + Number<2>{}); + + unpack2(cde_element_op, dst_data_refs, src_data_refs); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_threadwise_copy(i).MoveSrcSliceWindow( + ds_grid_desc_m0_m10_m11_n0_n10_n11[i], + make_multi_index(0, 0, 0, 0, 1, 0)); + }); + }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_threadwise_copy(i).MoveSrcSliceWindow( + ds_grid_desc_m0_m10_m11_n0_n10_n11[i], + make_multi_index( + 0, 0, 1, 0, -c_m10_m11_n10_n11_thread_tensor_lengths[I2], 0)); + }); + }); + static_for<0, NumDTensor, 1>{}([&](auto i) { + ds_threadwise_copy(i).MoveSrcSliceWindow( + ds_grid_desc_m0_m10_m11_n0_n10_n11[i], + make_multi_index( + 0, 1, -c_m10_m11_n10_n11_thread_tensor_lengths[I1], 0, 0, 0)); + }); + }); + + ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_thread_desc_m0_m10_m11_n0_n10_n11), + decltype(c_grid_desc_m0_m10_m11_n0_n10_n11), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, + c_m10_m11_n10_n11_thread_tensor_lengths[I0], + c_m10_m11_n10_n11_thread_tensor_lengths[I1], + 1, + c_m10_m11_n10_n11_thread_tensor_lengths[I2], + c_m10_m11_n10_n11_thread_tensor_lengths[I3]>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_grid_desc_m0_m10_m11_n0_n10_n11, + make_multi_index(im0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], + in0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]), + ck::tensor_operation::element_wise::PassThrough{}} + .Run(c_thread_desc_m0_m10_m11_n0_n10_n11, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_m10_m11_n0_n10_n11, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp new file mode 100755 index 000000000000..562b9b8ffa3b --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp @@ -0,0 +1,1128 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
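+// This header provides the DL gridwise GEMM kernels: kernel_gemm_dl_v1r3,
+// GridwiseGemmDl_km_kn_mn_v1r3, and the k-split variant GridwiseGemmDl_bkm_bkn_mn_v1r3.
+// All of them stage A/B tiles through double-buffered LDS and delegate the inner product to
+// BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2.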
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap block_2_ctile_map) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_grid_desc_k0_m0_m1_k1, + b_grid_desc_k0_n0_n1_k1, + c_grid_desc_m0_m10_m11_n0_n10_n11, + block_2_ctile_map, + integral_constant{}, + integral_constant{}); +} + +template +struct GridwiseGemmDl_km_kn_mn_v1r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // TODO: change this. 
I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k_m = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k_n = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = + math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = + math::integer_least_multiple(b_block_desc_k_n.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && + K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2)) && + (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0); + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) + { + const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1; + + return has_main_k_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0) + { + const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0; + + return has_double_tail_k_block_loop; + } + + __host__ __device__ static constexpr auto + MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1) + { + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + + const auto M1 = Number{}; + const auto M0 = M / M1; + + const auto a_grid_desc_k0_m0_m1_k1 = + transform_tensor_descriptor(a_grid_desc_k0_m_k1, + make_tuple(make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(M0, M1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return a_grid_desc_k0_m0_m1_k1; + } + + __host__ __device__ static constexpr auto + MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1) + { + const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + + const auto N1 = Number{}; + const auto N0 = N / N1; + + const auto b_grid_desc_k0_n0_n1_k1 = + transform_tensor_descriptor(b_grid_desc_k0_n_k1, + make_tuple(make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(N0, N1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + 
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return b_grid_desc_k0_n0_n1_k1; + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + constexpr auto M11 = + Number{}; + constexpr auto N11 = + Number{}; + + constexpr auto M10 = M1 / M11; + constexpr auto N10 = N1 / N11; + + const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), + make_unmerge_transform(make_tuple(N0, N10, N11))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_m0_m10_m11_n0_n10_n11; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N00_M01_N01( + c_grid_desc_m_n); + } + + using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})); + using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{})); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1, + const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11, + const Block2CTileMap& block_2_ctile_map, + integral_constant, + integral_constant) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m0_m1_k1.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n0_n1_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto c_m0_n0_block_cluster_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this forces index data into SGPR + const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]); + const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]); + + if(!block_2_ctile_map.ValidCTileIndex( + make_tuple(im0, in0), + make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0), + c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3)))) + { + return; + } + + // TODO: change this. 
I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // A matrix in LDS memory, for blockwise GEMM + constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, for blockwise GEMM + constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() == + a_k0_m_k1_block_desc.GetElementSpaceSize() && + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() == + b_k0_n_k1_block_desc.GetElementSpaceSize() && + "wrong!"); + + // A matrix blockwise copy + auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< + BlockSize, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + remove_reference_t, + decltype(a_block_desc_k0_m0_m1_k1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths + ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths + ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder + false, + true>(a_grid_desc_k0_m0_m1_k1, + make_multi_index(0, im0, 0, 0), + a_block_desc_k0_m0_m1_k1, + make_multi_index(0, 0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< + BlockSize, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + remove_reference_t, + decltype(b_block_desc_k0_n0_n1_k1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3>, + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths + BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths + BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3>, // DstVectorTensorContiguousDimOrder + false, + true>(b_grid_desc_k0_n0_n1_k1, + make_multi_index(0, in0, 0, 0), + b_block_desc_k0_n0_n1_k1, + make_multi_index(0, 0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[KPerBlocl, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + const auto blockwise_gemm = + BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + M1PerThreadM111, + N1PerThreadN111, + KPerThread, + M11N11ThreadClusterM110Xs, + 
M11N11ThreadClusterN110Xs, + M1PerThreadM111, + N1PerThreadN111>{}; + + constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = + decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); + + constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed( + sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize()); + + // Initialize C + c_thread_buf.Clear(); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0, 0); + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); + + auto a_block_odd_buf = make_dynamic_buffer( + p_a_block_double + a_block_aligned_space_size, + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + const auto K0 = a_grid_desc_k0_m0_m1_k1.GetLength(I0); + + index_t k_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, + b_block_slice_copy_step); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11, + a_block_even_buf, + b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, + b_block_slice_copy_step); + + // LDS double buffer: load next data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + 
c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_even_buf); + + k_block_data_begin += 2 * K0PerBlock; + } while(k_block_data_begin < K0 - 2 * K0PerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_n0_n1_k1, b_block_slice_copy_step); + + block_sync_lds(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_k0_n0_n1_k1, b_global_buf); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_k0_n0_n1_k1, b_block_odd_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = + blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id()); + + ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_thread_desc_m0_m10_m11_n0_n10_n11), + decltype(c_grid_desc_m0_m10_m11_n0_n10_n11), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, + c_m10_m11_n10_n11_thread_tensor_lengths[I0], + c_m10_m11_n10_n11_thread_tensor_lengths[I1], + 1, + c_m10_m11_n10_n11_thread_tensor_lengths[I2], + c_m10_m11_n10_n11_thread_tensor_lengths[I3]>, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_grid_desc_m0_m10_m11_n0_n10_n11, + make_multi_index(im0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], + in0, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]), + ck::tensor_operation::element_wise::PassThrough{}} + .Run(c_thread_desc_m0_m10_m11_n0_n10_n11, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_m10_m11_n0_n10_n11, + c_grid_buf); + } + } +}; + +template +struct GridwiseGemmDl_bkm_bkn_mn_v1r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // TODO: change this. 
I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_b_k0_m_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_b_k0_n_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_block_desc_b_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_block_desc_b_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB); + } + + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_B_K0_M_K1& a_grid_desc_b_k0_m_k1, + const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n) + { + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_b_k0_m_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB && + b_grid_desc_b_k0_n_k1.GetElementSpaceSize() * sizeof(FloatAB) <= TwoGB && + c_grid_desc_m_n.GetElementSpaceSize() * sizeof(FloatC) <= TwoGB)) + { + return false; + } + + const auto M = a_grid_desc_b_k0_m_k1.GetLength(I2); + const auto N = b_grid_desc_b_k0_n_k1.GetLength(I2); + const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1); + const auto KBatch = a_grid_desc_b_k0_m_k1.GetLength(I0); + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + return (M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_b_k0_n_k1.GetLength(I1) && + K1 == a_grid_desc_b_k0_m_k1.GetLength(I3) && + K1 == b_grid_desc_b_k0_n_k1.GetLength(I3)) && + KBatch == b_grid_desc_b_k0_n_k1.GetLength(I0) && + (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0); + } + + __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N) + { + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); + + return grid_size; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0) + { + const bool has_main_k_block_loop = (K0 + K0PerBlock) / (2 * K0PerBlock) > 1; + + return has_main_k_block_loop; + } + + __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0) + { + const bool has_double_tail_k_block_loop = (K0 / K0PerBlock) % 2 == 0; + + return has_double_tail_k_block_loop; + } + + __host__ __device__ static constexpr auto + MakeAGridDescriptor_B_K0_M0_M1_K1(const AGridDesc_B_K0_M_K1& a_grid_desc_b_k0_m_k1) + { + const auto KBatch = a_grid_desc_b_k0_m_k1.GetLength(I0); + const auto K0 = a_grid_desc_b_k0_m_k1.GetLength(I1); + const auto M = a_grid_desc_b_k0_m_k1.GetLength(I2); + + const auto M1 = Number{}; + const auto M0 = M / M1; + + const auto a_grid_desc_b_k0_m0_m1_k1 = transform_tensor_descriptor( + a_grid_desc_b_k0_m_k1, + make_tuple(make_pass_through_transform(KBatch), + make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(M0, M1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, 
Sequence<4>{})); + + return a_grid_desc_b_k0_m0_m1_k1; + } + + __host__ __device__ static constexpr auto + MakeBGridDescriptor_B_K0_N0_N1_K1(const BGridDesc_B_K0_N_K1& b_grid_desc_b_k0_n_k1) + { + const auto KBatch = b_grid_desc_b_k0_n_k1.GetLength(I0); + const auto K0 = b_grid_desc_b_k0_n_k1.GetLength(I1); + const auto N = b_grid_desc_b_k0_n_k1.GetLength(I2); + + const auto N1 = Number{}; + const auto N0 = N / N1; + + const auto b_grid_desc_b_k0_n0_n1_k1 = transform_tensor_descriptor( + b_grid_desc_b_k0_n_k1, + make_tuple(make_pass_through_transform(KBatch), + make_pass_through_transform(K0), + make_unmerge_transform(make_tuple(N0, N1)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{})); + + return b_grid_desc_b_k0_n0_n1_k1; + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + constexpr auto M11 = + Number{}; + constexpr auto N11 = + Number{}; + + constexpr auto M10 = M1 / M11; + constexpr auto N10 = N1 / N11; + + const auto c_grid_desc_m0_m10_m11_n0_n10_n11 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)), + make_unmerge_transform(make_tuple(N0, N10, N11))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_m0_m10_m11_n0_n10_n11; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( + const CGridDesc_M_N& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch) + { + return BlockToCTileMap_KSplit_M00_N00_M01_N01( + c_m_n_grid_desc, M01, N01, KBatch); + } + + using AGridDesc_B_K0_M0_M1_K1 = + decltype(MakeAGridDescriptor_B_K0_M0_M1_K1(AGridDesc_B_K0_M_K1{})); + using BGridDesc_B_K0_N0_N1_K1 = + decltype(MakeBGridDescriptor_B_K0_N0_N1_K1(BGridDesc_B_K0_N_K1{})); + using CGridDesc_M0_M10_M11_N0_N10_N11 = + decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1)); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const AGridDesc_B_K0_M0_M1_K1& a_grid_desc_b_k0_m0_m1_k1, + const BGridDesc_B_K0_N0_N1_K1& b_grid_desc_b_k0_n0_n1_k1, + const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11, + const CBlockClusterAdaptor& c_block_cluster_adaptor, + integral_constant, + integral_constant) + { + const auto a_global_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_b_k0_m0_m1_k1.GetElementSpaceSize()); + const auto b_global_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_b_k0_n0_n1_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_m10_m11_n0_n10_n11.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t k_batch_id = block_work_idx[I0]; + + if(!c_block_cluster_adaptor.ValidCTileIndex( + 
make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I0), + c_grid_desc_m0_m10_m11_n0_n10_n11.GetLength(I3)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I2]); + + // TODO: change this. I think it needs multi-dimensional alignment + constexpr auto max_lds_align = K1; + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(I1, Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(I1, Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // A matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, dst of blockwise copy + // be careful of LDS alignment + constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, I1, Number{}, K1), max_lds_align); + + // TODO: check alignment + // A matrix in LDS memory, for blockwise GEMM + constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + // TODO: check alignment + // B matrix in LDS memory, for blockwise GEMM + constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + + static_assert(a_block_desc_k0_m0_m1_k1.GetElementSpaceSize() == + a_k0_m_k1_block_desc.GetElementSpaceSize() && + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize() == + b_k0_n_k1_block_desc.GetElementSpaceSize() && + "wrong!"); + + // A matrix blockwise copy + auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< + BlockSize, + InMemoryDataOperationEnum::Set, + Sequence<1, K0PerBlock, 1, MPerBlock, K1.value>, + ABlockTransferThreadSliceLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterLengths_K0_M0_M1_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + remove_reference_t, + decltype(a_block_desc_b_k0_m0_m1_k1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3, 4>, + ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths + ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths + ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder + false, + true>(a_grid_desc_b_k0_m0_m1_k1, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0, 0), + a_block_desc_b_k0_m0_m1_k1, + make_multi_index(0, 0, 0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1< + BlockSize, + InMemoryDataOperationEnum::Set, + Sequence<1, K0PerBlock, 1, NPerBlock, K1.value>, + BBlockTransferThreadSliceLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterLengths_K0_N0_N1_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + remove_reference_t, + 
decltype(b_block_desc_b_k0_n0_n1_k1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2, 3, 4>, + BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths + BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths + BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder + Sequence<0, 1, 2, 3, 4>, // DstVectorTensorContiguousDimOrder + false, + true>(b_grid_desc_b_k0_n0_n1_k1, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0, 0), + b_block_desc_b_k0_n0_n1_k1, + make_multi_index(0, 0, 0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[KPerBlocl, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + const auto blockwise_gemm = + BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2< + BlockSize, + FloatAB, + FloatAB, + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + M1PerThreadM111, + N1PerThreadN111, + KPerThread, + M11N11ThreadClusterM110Xs, + M11N11ThreadClusterN110Xs, + M1PerThreadM111, + N1PerThreadN111>{}; + + constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths = + decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1(); + + constexpr auto c_thread_desc_m10_m11_n10_n11 = make_naive_tensor_descriptor_packed( + sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_aligned_space_size = math::integer_least_multiple( + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_aligned_space_size = math::integer_least_multiple( + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block_double = p_shared_block; + FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size; + + // register allocation for output + auto c_thread_buf = make_static_buffer( + c_thread_desc_m10_m11_n10_n11.GetElementSpaceSize()); + + // Initialize C + c_thread_buf.Clear(); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0, 0); + + auto a_block_even_buf = make_dynamic_buffer( + p_a_block_double, a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); + auto b_block_even_buf = make_dynamic_buffer( + p_b_block_double, b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); + + auto a_block_odd_buf = make_dynamic_buffer( + p_a_block_double + a_block_aligned_space_size, + a_block_desc_k0_m0_m1_k1.GetElementSpaceSize()); + auto b_block_odd_buf = make_dynamic_buffer( + p_b_block_double + b_block_aligned_space_size, + b_block_desc_k0_n0_n1_k1.GetElementSpaceSize()); + + // LDS double buffer: preload data into LDS + { + a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf); + + a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_even_buf); + } + + if constexpr(HasMainKBlockLoop) + { + const auto K0 = a_grid_desc_b_k0_m0_m1_k1.GetLength(I1); + + index_t k_block_data_begin = 0; + + // LDS double buffer: main body + // use Do-While loop instead of For loop to simplify control flow + do + { + // even iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1, + a_block_slice_copy_step); + 
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1, + b_block_slice_copy_step); + + // LDS double buffer: load next data from device mem + a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run(c_thread_desc_m10_m11_n10_n11, + a_block_even_buf, + b_block_even_buf, + c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_odd_buf); + + // odd iteration + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1, + a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1, + b_block_slice_copy_step); + + // LDS doubel buffer: load next data from device mem + a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on current data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + + // LDS double buffer: store next data to LDS + a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_even_buf); + b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_even_buf); + + k_block_data_begin += 2 * K0PerBlock; + } while(k_block_data_begin < K0 - 2 * K0PerBlock); + } + + // LDS double buffer: tail + if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left + { + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_b_k0_m0_m1_k1, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_b_k0_n0_n1_k1, b_block_slice_copy_step); + + block_sync_lds(); + + // LDS double buffer: load last data from device mem + a_blockwise_copy.RunRead(a_grid_desc_b_k0_m0_m1_k1, a_global_buf); + b_blockwise_copy.RunRead(b_grid_desc_b_k0_n0_n1_k1, b_global_buf); + + // LDS double buffer: GEMM on 2nd-last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); + + // LDS double buffer: store last data to LDS + a_blockwise_copy.RunWrite(a_block_desc_b_k0_m0_m1_k1, a_block_odd_buf); + b_blockwise_copy.RunWrite(b_block_desc_b_k0_n0_n1_k1, b_block_odd_buf); + + block_sync_lds(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_odd_buf, b_block_odd_buf, c_thread_buf); + } + else // if has 1 iteration left + { + __syncthreads(); + + // LDS double buffer: GEMM on last data + blockwise_gemm.Run( + c_thread_desc_m10_m11_n10_n11, a_block_even_buf, b_block_even_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_m10_m11_n0_n10_n11 = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + const auto c_m10_m11_n10_n11_thread_origin_idx_on_block = + blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1( + get_thread_local_1d_id()); + + ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, + FloatC, + decltype(c_thread_desc_m0_m10_m11_n0_n10_n11), + decltype(c_grid_desc_m0_m10_m11_n0_n10_n11), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, + c_m10_m11_n10_n11_thread_tensor_lengths[I0], + c_m10_m11_n10_n11_thread_tensor_lengths[I1], + 1, + c_m10_m11_n10_n11_thread_tensor_lengths[I2], + c_m10_m11_n10_n11_thread_tensor_lengths[I3]>, 
+ CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{c_grid_desc_m0_m10_m11_n0_n10_n11, + make_multi_index(m_block_data_idx_on_grid, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I0], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], + n_block_data_idx_on_grid, + c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], + c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]), + ck::tensor_operation::element_wise::PassThrough{}} + .Run(c_thread_desc_m0_m10_m11_n0_n10_n11, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_m10_m11_n0_n10_n11, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp new file mode 100755 index 000000000000..b473d7cbf2b7 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp @@ -0,0 +1,701 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif +#if CK_USE_WAVES_PER_EU + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) +#endif + kernel_gemm_dpp(const typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx103__) || defined(__gfx11__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane( + GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(karg.M, karg.K, karg.AK0, karg.StrideA)); + const auto b_grid_desc_bk0_n_bk1 = amd_wave_read_first_lane( + GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(karg.K, karg.N, karg.BK0, karg.StrideB)); + const auto c_grid_desc_m_n = amd_wave_read_first_lane( + GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC)); + + GridwiseGemm::template Run(karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_m_n); +#else + ignore = karg; +#endif +} + +template +struct GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static 
constexpr auto BK0PerBlock = Number{}; + + static constexpr auto max_lds_align = math::lcm(AK1, BK1); + + using ThisThreadBlock = ThisThreadBlock; + // return block_id to C matrix tile idx (m0, n0) mapping + using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt; + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock) * MPerBlock; + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; + } + + __host__ static auto CalculateAK0(index_t K) { return math::integer_divide_floor(K, AK1Value); } + __host__ static auto CalculateBK0(index_t K) { return math::integer_divide_floor(K, BK1Value); } + + // Argument + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + AK0{CalculateAK0(K)}, + BK0{CalculateBK0(K)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t AK0; + index_t BK0; + }; + + // Argument + struct Argument : public Problem, public tensor_operation::device::BaseArgument + { + __host__ Argument(const ABDataType* p_a_grid_, + const ABDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const ABDataType* p_a_grid; + const ABDataType* p_b_grid; + CDataType* p_c_grid; + }; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, AK1), max_lds_align); + } + }(); + + return a_block_desc_ak0_m_ak1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, BK1), max_lds_align); + } + }(); + + return b_block_desc_bk0_n_bk1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + 
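+        // Each tile's element space is rounded up to max_lds_align (lcm(AK1, BK1)) before
+        // the two sizes are summed, so A and B can share one contiguous LDS allocation with
+        // the B tile starting at an aligned offset. This pipeline keeps a single copy of
+        // each tile in LDS, hence no factor of two here, unlike the double-buffered DL
+        // kernel earlier in this patch.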
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType); + } + + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert(is_known_at_compile_time>::value, + "Wrong! AK1 must be known at the time of compilation."); + static_assert(is_known_at_compile_time>::value, + "Wrong! BK1 must be known at the time of compilation."); + + static_assert( + MPerBlock % (MPerDpp * MDppPerWave) == 0, + "Invalid tuning parameters! MPerBlock must be divisible by MPerDpp * MDppPerWave."); + static_assert( + NPerBlock % (NPerDpp * NDppPerWave) == 0, + "Invalid tuning parameters! NPerBlock must be divisible by NPerDpp * NDppPerWave."); + + static_assert( + KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "Invalid tuning parameters! KPerBlock must be divisible by both AK1 and BK1."); + + static_assert(AK1Value % ABlockTransferDstScalarPerVector_K1 == 0, + "Invalid tuning parameters! AK1Value must be divisible by " + "ABlockTransferDstScalarPerVector_K1"); + + static_assert(BK1Value % BBlockTransferDstScalarPerVector_K1 == 0, + "Invalid tuning parameters! BK1Value must be divisible by " + "BBlockTransferDstScalarPerVector_K1"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if(problem.K % KPerBlock != 0) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = problem.K / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const auto num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc& c_grid_desc_m_n) + { + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = 
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), DppSelector::selected_dpp.k_per_dpp); + + using BlockwiseGemm = + BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2; + + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n); + } + + static constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + __device__ static auto + MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t AK0, index_t StrideA) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + __device__ static auto + MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t BK0, index_t StrideB) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + return transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_pass_through_transform(N), + make_unmerge_transform(make_tuple(BK0, BK1Value))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + } + + __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw); + } + + template + __device__ static void Run(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = + MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_n2.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + const auto block_2_ctile_map = + Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)}; + + // divide block work by [M, N] + const auto block_work_idx = + 
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I0), + c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I1)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[AK0PerBlock, MPerBlock] is in LDS + // b_mtx[BK0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + constexpr index_t KPack = math::max( + math::lcm(AK1, BK1), DppSelector::selected_dpp.k_per_dpp); + auto blockwise_gemm = + BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(AK0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(BK0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const auto AK0 = a_grid_desc_ak0_m_ak1.GetLength(I0); + // (AK0 / AK0PerBlock) is always equal to (BK0 / BK0PerBlock) + const index_t num_k_block_main_loop = 
__builtin_amdgcn_readfirstlane(AK0 / AK0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2(); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); + + constexpr auto MPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I4); + constexpr auto NPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_grid_to_m0_m1_m2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_grid_desc_m0_n0_m1_n1_m2_n2, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + n_thread_data_on_grid_idx[I2]), + c_element_op}; + + c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2, + make_tuple(I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_n0_m1_n1_m2_n2, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp new file mode 100755 index 000000000000..054aca293694 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -0,0 +1,1079 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. 
All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +namespace ck { + +// GEMM: +// input : A0[M, K], A1[M, K] +// input : B0[N, K], B1[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct GridwiseGemmMultipleABD_xdl_cshuffle +{ + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + +#if CK_GFX90A_DENORM_WORKAROUND + using AComputeDataType = + conditional_t, ck::bhalf_t, AComputeDataType_>; + using BComputeDataType = + conditional_t, ck::bhalf_t, BComputeDataType_>; +#else + using AComputeDataType = AComputeDataType_; + using BComputeDataType = BComputeDataType_; +#endif + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + static constexpr auto MakeAsGridPointer() + { + return generate_tuple( + [&](auto i) { + using ADataType = 
remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + static constexpr auto MakeBsGridPointer() + { + return generate_tuple( + [&](auto i) { + using BDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) + + b_block_space_size_aligned * sizeof(BComputeDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + // A desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeDefaultAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k) + { + return generate_tuple( + [&](auto i) { return MakeDefaultAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); }, + Number{}); + } + + // B desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeDefaultBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k) + { + return generate_tuple( + [&](auto i) { return MakeDefaultBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); }, + Number{}); + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + 
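+        // Unmerge M into (MBlock, MPerBlock) and N into (NBlock, NPerBlock) so a workgroup
+        // can address its output tile directly by (m_block, n_block); the D tensors get the
+        // same transform below, keeping the shuffled C, the Ds and E index-compatible in the
+        // epilogue copy.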
+ const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AsGridDesc_M_K& as_grid_desc_m_k, + const BsGridDesc_N_K& bs_grid_desc_n_k, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "KPerBlock must be divisible by AK1Value and BK1Value!"); + + const auto M = as_grid_desc_m_k[I0].GetLength(I0); + const auto N = bs_grid_desc_n_k[I0].GetLength(I0); + const auto AK = as_grid_desc_m_k[I0].GetLength(I1); + const auto BK = bs_grid_desc_n_k[I0].GetLength(I1); + + // check consistency of desc + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK)) + { + return false; + } + + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + bool valid = true; + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType = remove_cvref_t>; + valid = + valid && (as_grid_desc_m_k[i].GetElementSpaceSize() * sizeof(ADataType) <= TwoGB); + valid = valid && (M == as_grid_desc_m_k[i].GetLength(I0) && + AK == as_grid_desc_m_k[i].GetLength(I1)); + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType = remove_cvref_t>; + valid = + valid && (bs_grid_desc_n_k[i].GetElementSpaceSize() * sizeof(BDataType) <= TwoGB); + valid = valid && (N == bs_grid_desc_n_k[i].GetLength(I0) && + BK == bs_grid_desc_n_k[i].GetLength(I1)); + }); + + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0)) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = AK / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // check block-to-E-tile + if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + + if(!(e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + 
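+            // (the 2 GiB bound here, and in the per-tensor A/B checks above, presumably
+            //  keeps byte offsets within each tensor representable by the 32-bit index
+            //  arithmetic used for buffer addressing)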
} + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using AsGridPointer = decltype(MakeAsGridPointer()); + using BsGridPointer = decltype(MakeBsGridPointer()); + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __host__ __device__ static auto + MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + template + __host__ __device__ static auto MakeAsGridDescriptor_M_K( +#ifdef CK_CODE_GEN_RTC + const ck::Array& MRaws, + const ck::Array& KRaws, + const ck::Array& AsStride +#else + const std::array& MRaws, + const std::array& KRaws, + const std::array& AsStride +#endif + ) + { + return generate_tuple( + [&](auto i) { + using ALayout = remove_cvref_t>; + + return MakeAGridDescriptor_M_K(MRaws[i], KRaws[i], AsStride[i]); + }, + Number{}); + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_N_K(const index_t NRaw, const index_t KRaw, const index_t StrideB) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + __host__ __device__ static auto MakeBsGridDescriptor_N_K( +#ifdef CK_CODE_GEN_RTC + const ck::Array& NRaws, + const ck::Array& KRaws, + const ck::Array& BsStride +#else + const std::array& NRaws, + const std::array& KRaws, + const std::array& BsStride +#endif + ) + { + return generate_tuple( + [&](auto i) { + using BLayout = remove_cvref_t>; + + return MakeBGridDescriptor_N_K(NRaws[i], KRaws[i], BsStride[i]); + }, + Number{}); + } + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto MakeDsGridDescriptor_M_N( +#ifdef CK_CODE_GEN_RTC + const ck::Array& MRaws, + const ck::Array& NRaws, + const ck::Array& DsStride +#else + const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride +#endif + ) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return 
MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(AsGridPointer p_as_grid, + BsGridPointer p_bs_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1, + const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto as_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize()); + }, + Number{}); + + const auto bs_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize()); + }, + Number{}); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + const auto idx_as_block_begin = + generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, + Number{}); + + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + AsDataType, + Tuple, + decltype(as_grid_desc_ak0_m_ak1), + decltype(tie(a_block_desc_ak0_m_ak1)), + AElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + uniform_sequence_gen_t, + Sequence>{as_grid_desc_ak0_m_ak1, + idx_as_block_begin, + tie(a_block_desc_ak0_m_ak1), + make_tuple(make_multi_index(0, 0, 0)), + a_element_op}; + + const auto idx_bs_block_begin = + generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, + Number{}); + + auto b_blockwise_copy = 
ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + BsDataType, + Tuple, + decltype(bs_grid_desc_bk0_n_bk1), + decltype(tie(b_block_desc_bk0_n_bk1)), + BElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + uniform_sequence_gen_t, + Sequence>{bs_grid_desc_bk0_n_bk1, + idx_bs_block_begin, + tie(b_block_desc_bk0_n_bk1), + make_tuple(make_multi_index(0, 0, 0)), + b_element_op}; + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || + is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? true + : false; + static constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + AComputeDataType, + BComputeDataType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) / + KPerBlock); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + gridwise_gemm_pipeline.template Run(as_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + as_grid_buf, + a_block_buf, + a_block_slice_copy_step, + bs_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + bs_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
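+            // Overall flow of this write-out: the accumulator tile sits in VGPRs in the
+            // xdlops (M0,N0,M1,N1,M2,M3,M4,N2) layout, which is not contiguous along N, so
+            // it is staged through LDS one CShuffleMXdlPerWavePerShuffle x
+            // CShuffleNXdlPerWavePerShuffle slice at a time and re-read in
+            // (MBlock, MPerBlock, NBlock, NPerBlock) order, letting the fused C/Ds/E store
+            // below be vectorized along NPerBlock.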
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = 
concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde_element_op}; + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + 
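+                    // c_ds_desc_refs stores the LDS C-shuffle descriptor at index 0 followed
+                    // by the NumDTensor global D descriptors, hence the i + I1 below; only
+                    // the global-memory D windows advance along the space-filling curve, the
+                    // LDS source stays fixed.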
cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run(AsGridPointer p_as_grid, + BsGridPointer p_bs_grid, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, +#ifdef CK_CODE_GEN_RTC + const ck::Array StrideAs, + const ck::Array StrideBs, + const ck::Array StrideDs, +#else + const std::array StrideAs, + const std::array StrideBs, + const std::array StrideDs, +#endif + const index_t StrideE, + const Block2ETileMap& block_2_etile_map) + { + using AsGridDesc_M_K = + remove_cvref_t({}, {}, {}))>; + using BsGridDesc_N_K = + remove_cvref_t({}, {}, {}))>; + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + AsGridDesc_M_K as_grid_desc_m_k; + BsGridDesc_N_K bs_grid_desc_n_k; + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumATensor, 1>{}([&](auto j) { + using ALayout = remove_cvref_t>; + + as_grid_desc_m_k(j) = MakeAGridDescriptor_M_K(M, K, StrideAs[j]); + }); + + static_for<0, NumBTensor, 1>{}([&](auto j) { + using BLayout = remove_cvref_t>; + + bs_grid_desc_n_k(j) = MakeBGridDescriptor_N_K(N, K, StrideBs[j]); + }); + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto as_grid_desc_ak0_m_ak1 = MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k); + + const auto bs_grid_desc_bk0_n_bk1 = MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + Run(p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + as_grid_desc_ak0_m_ak1, + bs_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp new file mode 100755 index 000000000000..127d88957244 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp @@ -0,0 +1,958 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +template +struct GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr index_t NumRTensor = RsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + template + static constexpr auto MakeTsGridPointer() + { + return generate_tuple( + [&](auto i) { + using T = remove_cvref_t>; + if constexpr(isConst) + return static_cast(nullptr); + else + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + 
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // A desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // B desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, + const BGridDesc_N_K& b_grid_desc_n_k, + const EGridDesc_M_N& e_grid_desc_m_n, + const RGridDesc_M& r_grid_desc_m, + const Block2ETileMap& block_2_etile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + static_assert(AGridDesc_M_K::GetNumOfDimension() == 2); + static_assert(BGridDesc_N_K::GetNumOfDimension() == 2); + static_assert(EGridDesc_M_N::GetNumOfDimension() == 2); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + if(M != r_grid_desc_m.GetLength(I0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = 
e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr auto + MakeRGridDescriptor_MBlock_MPerBlock(const RGridDesc_M& r_grid_desc_m) + { + const auto M = r_grid_desc_m.GetLength(I0); + const auto MBlock = M / MPerBlock; + + const auto r_grid_desc_mblock_mperblock = transform_tensor_descriptor( + r_grid_desc_m, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{}))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + + return r_grid_desc_mblock_mperblock; + } + + // return block_id to E matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + using DefaultAGridDesc_AK0_M_AK1 = + remove_cvref_t; + using DefaultBGridDesc_BK0_N_BK1 = + remove_cvref_t; + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + // Support 2 dimension in the future. Not only M + using RGridDescriptor_MBlock_MPerBlock = + remove_cvref_t; + + using DefaultBlock2ETileMap = + remove_cvref_t; + + using DsGridPointer = decltype(MakeTsGridPointer()); + using RsGridPointer = decltype(MakeTsGridPointer()); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + FloatE* __restrict__ p_e_grid, + RsGridPointer p_rs_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const QsElementwiseOperation& qs_element_op, + const RsElementwiseOperation& rs_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const StaticallyIndexedArray& + ds_grid_desc_mblock_mperblock_nblock_nperblock, // FIXME: Ds desc may be of different + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const StaticallyIndexedArray& + rs_grid_desc_mblock_mperblock, // FIXME: Rs desc may be of different + const Block2ETileMap& block_2_etile_map) + { + // FIXME - Share code with other gemm kernel + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + auto rs_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_rs_grid(i), rs_grid_desc_mblock_mperblock[i].GetElementSpaceSize()); + }, + Number{}); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + 
if(!block_2_etile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? 
true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( + lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C + Ds + reduction + write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + 
SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_der_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + // TODO: this should be implemented as a blockwise reduction + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert(CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I0) * + CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I1) == + BlockSize, + "wrong!"); + + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % + CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I0) == + 0 && + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) % + CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) / + CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / + CDRThreadTransferClusterLengths_MPerBlock_NPerBlock::At(I1); + + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; + + // VGPR cde_reduce_thread_desc_mperblock_nperblock + constexpr auto cde_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr auto r_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + constexpr auto r_thread_desc_mblock_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); + + auto e_thread_buf = make_static_buffer( + cde_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + // reduce: threadwise copy from LDS to VGPR + constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( + CDRThreadTransferClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = + c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; + + // To apply D0, D1, ... and reduction. 
+ // Copy c shuffle from LDS back to VGPR + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatCShuffle, + FloatReduceAcc, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(cde_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CDEReduceThreadTransferScalarPerVector_NPerBlock, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; + + // Copy result of reduction back from VGPR to global + auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple( + [&](auto I) { + auto p_r_grid = p_rs_grid[I]; + auto r_element_op = rs_element_op[I]; + auto r_grid_desc_mblock_mperblock = rs_grid_desc_mblock_mperblock[I]; + + return ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + remove_pointer_t, + decltype(r_thread_desc_mblock_mperblock), + decltype(r_grid_desc_mblock_mperblock), + decltype(r_element_op), + Sequence<1, mreduce_per_thread>, + Sequence<0, 1>, + 1, + RThreadTransferDstScalarPerVector_MPerBlock, + RsGlobalMemoryDataOperation::At(I), + 1, + false>{r_grid_desc_mblock_mperblock, + make_multi_index(block_work_idx[I0], // mblock + c_reduce_thread_data_idx_begin[I0]), // mperblock + r_element_op}; + }, + Number{}); + + // D0, D1, ..., Dn + constexpr auto cde_reduce_thread_desc_I1_mperblock_I1_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + // FIXME: Decrease usage of VGPR + // Apply pointwise lambda function from multi-source (Global and LDS) into VGPR + auto ds_thread_buf = generate_tuple( + [&](auto) { + return make_static_buffer( + cde_reduce_thread_desc_I1_mperblock_I1_nperblock.GetElementSpaceSize()); + }, + Number{}); + + // Copy D0, D1, ..., Dn from global to VGPR + auto ds_thread_copy_global_to_vgpr = generate_tuple( + [&](auto I) { + using DDataType = remove_cvref_t>; + return ThreadwiseTensorSliceTransfer_v2< + DDataType, + FloatReduceAcc, + decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock[I]), + decltype(cde_reduce_thread_desc_I1_mperblock_I1_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + CDEReduceThreadTransferScalarPerVector_NPerBlock, + 1, + true>(ds_grid_desc_mblock_mperblock_nblock_nperblock[I], + make_multi_index( + I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1])); + }, + Number{}); + + auto e_thread_copy_vgpr_to_global = ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + FloatE, + decltype(cde_reduce_thread_desc_I1_mperblock_I1_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + tensor_operation::element_wise::PassThrough, + Sequence, // SliceLengths + Sequence<0, 1, 2, 3>, // DimAccessOrder + 3, // DstVectorDim + CDEReduceThreadTransferScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{ + e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(I0, + m_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I0], + I0, + n_block_data_idx_on_grid + c_reduce_thread_data_idx_begin[I1]), + tensor_operation::element_wise::PassThrough{}}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_der_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to read from LDS + block_sync_lds(); + + // each thread shuffle data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + 
sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to write to LDS + block_sync_lds(); + + // Get shuffle data from LDS to VGPR + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + cde_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + e_thread_buf); + + // Global read D0, D1, ... + static_for<0, NumDTensor, 1>{}([&](auto Id) { + auto& d_thread_copy_global_to_vgpr = ds_thread_copy_global_to_vgpr(Id); + d_thread_copy_global_to_vgpr.Run( + ds_grid_desc_mblock_mperblock_nblock_nperblock[Id], + ds_grid_buf[Id], + cde_reduce_thread_desc_I1_mperblock_I1_nperblock, + make_tuple(I0, I0, I0, I0), + ds_thread_buf(Id)); + + if constexpr(access_id < num_access - 1) + { + // move on D0, D1, ... + constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id); + d_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + ds_grid_desc_mblock_mperblock_nblock_nperblock[Id], de_global_step); + } + }); + + // cde_element_op(e, c, d0, d1, ...); + static_for<0, cde_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + const auto c_ds_src_data_refs = concat_tuple_of_reference( + tie(e_thread_buf[i]), + generate_tie( + [&](auto Id) -> const auto& { return ds_thread_buf[Id][i]; }, + Number{})); + auto e_dst_data_refs = tie(e_thread_buf(i)); + unpack2(cde_element_op, e_dst_data_refs, c_ds_src_data_refs); + }); + + // Global write E + e_thread_copy_vgpr_to_global.Run(cde_reduce_thread_desc_I1_mperblock_I1_nperblock, + make_tuple(I0, I0, I0, I0), + e_thread_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + if constexpr(access_id < num_access - 1) + { + // move on E + constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id); + e_thread_copy_vgpr_to_global.MoveDstSliceWindow( + e_grid_desc_mblock_mperblock_nblock_nperblock, de_global_step); + } + + // reduction + static_for<0, NumRTensor, 1>{}([&](auto Ir) { + auto r_thread_buf = make_static_buffer( + r_thread_desc_mperblock.GetElementSpaceSize()); + + auto& reduce_thread_copy_vgpr_to_global = + reduce_tuple_thread_copy_vgpr_to_global(Ir); + + using ThreadReduceOperation = + remove_cvref_t; + + using ThreadwiseReduce = + ThreadwiseReduction; + + // threadwise reduction + const auto reduce_identityVal = + ThreadReduceOperation::template GetIdentityValue(); + static_for<0, mreduce_per_thread, 1>{}( + [&](auto I) { r_thread_buf(I) = reduce_identityVal; }); + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto offset = + Number{}; + + qs_element_op[Ir](e_thread_buf(offset), e_thread_buf(offset)); + }); + }); + ThreadwiseReduce::Reduce(e_thread_buf, r_thread_buf); + + // gridwise reduction + reduce_thread_copy_vgpr_to_global.Run(r_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + r_thread_buf, + rs_grid_desc_mblock_mperblock[Ir], + rs_grid_buf(Ir)); + + if constexpr(access_id < num_access - 1) + { + // move on R0, R1, ... 
+ constexpr auto de_global_step = sfc_der_global.GetForwardStep(access_id); + reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + rs_grid_desc_mblock_mperblock[Ir], + make_tuple(de_global_step[I0], de_global_step[I1])); + } + }); + }); // copy c, d, e + reduction + + } // shuffle C + Ds + reduction + write out + } // Run +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp new file mode 100755 index 000000000000..de6c9c1601ae --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -0,0 +1,1376 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_conv_multiple_d_wmma_cshuffle( + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const index_t batch_count, + const AGridDesc_AK0_M_AK1 a_grid_desc, + const BGridDesc_BK0_N_BK1 b_grid_desc, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock_, + const Block2CTileMap block_2_ctile_map, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + // offset base pointer for each work-group + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; + + DsPointer p_ds_grid_grp; + + static constexpr index_t NumDTensor = + 
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseOp::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_grid_desc, + b_grid_desc, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + a_element_op, + b_element_op, + cde_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_grid_desc; + ignore = b_grid_desc; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = compute_ptr_offset_of_batch; + ignore = block_2_ctile_map; +#endif +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_contraction_multiple_d_wmma_cshuffle( + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const index_t batch_count, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const Block2CTileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + // printf("entry kernel launch"); + __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; + + const index_t num_blocks_per_batch = + __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); + + const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); + + static constexpr index_t NumDTensor = + DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size(); + + DsPointer p_ds_grid_grp; + + static_for<0, NumDTensor, 1>{}( + [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); + + GridwiseOp::template Run(p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_grid_desc, + b_grid_desc, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + cde_element_op, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = batch_count; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc; + ignore = b_grid_desc; + ignore = 
ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; + ignore = compute_ptr_offset_of_batch; +#endif +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_mupltipe_d_wmma_cshuffle( + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size]; + + GridwiseOp::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_grid_desc, + b_grid_desc, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + cde_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_grid_desc; + ignore = b_grid_desc; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx11__ )) +} + +template < // DataType Family + typename ADataType, + typename BDataType, + typename AccDataType, + typename CShuffleDataType, + typename DsDataType, + typename EDataType, + // InMemory Data Descriptor + typename AGridDesc, + typename BGridDesc, + typename DsGridDesc_M_N, + typename EGridDesc_M_N, + // ElementwiseOp Family + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CDEElementwiseOperation, + InMemoryDataOperationEnum EGlobalMemoryDataOperation, + // Tiling Family + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t MPerWmma, + index_t NPerWmma, + index_t K1Value, + index_t MRepeat, + index_t NRepeat, + // ThreadCluster Family + index_t BlockSize, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_K1, + bool AThreadTransferSrcResetCoordinateAfterRun, + bool AEnableLds, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BEnableLds, + bool BBlockLdsExtraN, + index_t CShuffleMRepeatPerShuffle, + index_t CShuffleNRepeatPerShuffle, + typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock, + index_t NumGemmKPrefetchStage = 1, + LoopScheduler LoopSched = 
make_default_loop_scheduler(), + PipelineVersion PipelineVer = PipelineVersion::v1> +struct GridwiseGemmMultipleD_Wmma +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 32 : 16; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = + remove_cvref_t())>; + + // Describe how data store to (LDS/VGPR) buffer from Global memory + __host__ __device__ static constexpr auto MakeABlockDescriptor() + { + constexpr auto a_block_desc = [&]() { + if constexpr(AEnableLds) + { + // K0->M->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + } + else + { + constexpr auto A_KRow = I2; + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / A_KRow / K1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); + } + }(); + + return a_block_desc; + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor() + { + constexpr auto b_block_desc = [&]() { + if constexpr(BEnableLds) + { + // K0->N->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + } + else + { + constexpr auto B_KRow = I2; + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / B_KRow / K1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); + } + }(); + + return b_block_desc; + } + + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() + { + constexpr auto a_block_copy_step = [&]() { + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return a_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep() + { + constexpr auto b_block_copy_step = [&]() { + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return 
make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b_block_copy_step; + } + + // Describe how data read from (LDS/VGPR) buffer + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + { + + constexpr auto a_wave_desc = [&]() { + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else + constexpr auto A_KRow = I1; +#endif + return transform_tensor_descriptor( + ABlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3); + constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6); + + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return a_wave_desc; + } + + template + __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&) + { + constexpr auto b_wave_desc = [&]() { + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else + constexpr auto B_KRow = I1; +#endif + return transform_tensor_descriptor( + BBlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = BBlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3); + constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return b_wave_desc; + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + // CheckValidity for kernels without multi D + template + __host__ __device__ static constexpr bool 
CheckValidity(const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerWmma)) == 0, + "Invalid tuning param!"); + + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetBProblemsizeNK = [&]() { + if constexpr(BEnableLds) + { + return make_tuple(b_grid_desc.GetLength(I1), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * + b_grid_desc.GetLength(I5), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * + b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto N = GetBProblemsizeNK()[I0]; + const auto K = GetAProblemsizeMK()[I1]; + + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && + K == GetBProblemsizeNK()[I1])) + { + printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); + return false; + } + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! 
K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerWmma)) == 0, + "Invalid tuning param!"); + + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetBProblemsizeNK = [&]() { + if constexpr(BEnableLds) + { + return make_tuple(b_grid_desc.GetLength(I1), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * + b_grid_desc.GetLength(I5), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * + b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto N = GetBProblemsizeNK()[I0]; + const auto K = GetAProblemsizeMK()[I1]; + + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + printf("GridwiseOp: D descriptor dimension check failure\n"); + return false; + } + + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && + K == GetBProblemsizeNK()[I1])) + { + printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); + return false; + } + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N_& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + 
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N_& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const EGridDesc_M_N& e_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + struct SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + + static constexpr auto max_lds_align = K1; + + static constexpr auto a_block_space_size_aligned = + AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + static constexpr auto b_block_space_size_aligned = + BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b_block_space_offset = a_block_space_size_aligned; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_space_size = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + .GetElementSpaceSize(); + + static constexpr auto c_shuffle_block_space_offset = 0; + + static constexpr auto lds_size = + math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType), + a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)); + }; + + using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DefaultBlock2CTileMap = + remove_cvref_t; + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const Block2CTileMap& block_2_ctile_map) + { + // clang-format off +/*******************************************************************************/ +// Memory buffer zone. 
+ const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc.GetElementSpaceSize()); + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + +/*******************************************************************************/ +// BlockIdx.x -> [BlockId.m, BlockId.n] + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { return; } + + // Store BlockId into SGPR + const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + +/*******************************************************************************/ +// BlockLevel, A/B Matrix ThreadMapping in LDS, As Destinaion of BlockWise_Copy + const auto K = [&](){ + if constexpr(AEnableLds){ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); + } + else{ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6); + } + }(); + + constexpr auto a_block_desc = MakeABlockDescriptor(); + constexpr auto b_block_desc = MakeBBlockDescriptor(); + + auto a_block_trait = [&](){ + // A matrix blockwise copy + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + a_block_desc.GetElementSpaceSize()); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, +/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, +/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ ADataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(a_grid_desc), +/* typename DstDesc, */ decltype(a_block_desc), +/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + a_grid_desc, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto a_block_buf = make_static_buffer( + a_block_desc.GetElementSpaceSize()); + + // 
Limitation: NumDim of Src and Dst descriptor should be identical + auto a_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + ABlockTransferSrcScalarPerVector, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc, + make_multi_index(0, + m_block_data_idx_on_grid/(MWaves * MPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + }; + + auto b_block_trait = [&](){ + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::a_block_space_size_aligned, + b_block_desc.GetElementSpaceSize()); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc), + decltype(b_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto b_block_buf = make_static_buffer( + b_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc, + make_multi_index(0, + n_block_data_idx_on_grid/(NWaves * NPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + }; + + auto a_block_buf = a_block_trait()[I0]; + auto a_blockwise_copy = a_block_trait()[I1]; + + auto b_block_buf = b_block_trait()[I0]; + auto b_blockwise_copy = b_block_trait()[I1]; +/*******************************************************************************/ + // GEMM + constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); + + auto blockwise_gemm = + BlockwiseGemmWMMA{}; + + // Prepare Register for C matrix + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + +/*******************************************************************************/ + // Shift Per SUB_K + constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); + constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep(); + + // gridwise GEMM pipeline + const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock); + GridwiseGemmPipe::template Run(a_grid_desc, + a_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc, + b_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + 
KBlockMainLoop); +/*******************************************************************************/ + // write out to C, implement shuffle + { + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // This API Provide All dimension (size) you need + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1); + constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2); + constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4); + constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5); + constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize()); + + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_idx = 
n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 1, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor buffers + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // shuffle: blockwise copy C from LDS to global + auto cde_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, // ThreadGroup + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, // ElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // bool ThreadTransferSrcResetCoordinateAfterRun, + Sequence> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_cde_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + 
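Each pass of the loop that follows stages one shuffle slab of the accumulator: synchronize, write the slab from VGPRs into the LDS scratch, synchronize again, copy it (together with any D tensors) out to global memory, then advance the slice windows along the space-filling curve. A compressed host-side model of that staging order, with plain arrays standing in for VGPRs, LDS, and global memory (sizes and the slab traversal order are illustrative only, not the real space-filling-curve order):

// Host model of the epilogue staging loop: the per-thread accumulator tile is
// pushed through a reused shared scratch one shuffle slab at a time, with a
// barrier before each LDS write and before each LDS read.
#include <cstdio>
#include <vector>

int main()
{
    constexpr int m_repeat = 4, n_repeat = 4;           // accumulator tile held in "VGPRs"
    constexpr int m_per_shuffle = 2, n_per_shuffle = 2; // slab staged per iteration
    constexpr int num_access = (m_repeat / m_per_shuffle) * (n_repeat / n_per_shuffle);

    std::vector<float> vgpr(m_repeat * n_repeat, 1.0f);    // accumulator fragment
    std::vector<float> lds(m_per_shuffle * n_per_shuffle); // reused scratch slab
    std::vector<float> global_out(m_repeat * n_repeat, 0.0f);

    for(int access = 0; access < num_access; ++access)
    {
        // barrier: previous readers of the slab are done (block_sync_lds)
        const int m0 = (access / (n_repeat / n_per_shuffle)) * m_per_shuffle;
        const int n0 = (access % (n_repeat / n_per_shuffle)) * n_per_shuffle;
        for(int i = 0; i < m_per_shuffle; ++i)
            for(int j = 0; j < n_per_shuffle; ++j)
                lds[i * n_per_shuffle + j] = vgpr[(m0 + i) * n_repeat + (n0 + j)];
        // barrier: slab fully written before it is read back (block_sync_lds)
        for(int i = 0; i < m_per_shuffle; ++i)
            for(int j = 0; j < n_per_shuffle; ++j)
                global_out[(m0 + i) * n_repeat + (n0 + j)] = lds[i * n_per_shuffle + j];
    }
    std::printf("staged %d slabs\n", num_access);
}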
block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_shuffle_block_copy_lds_to_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id); + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_shuffle_block_copy_lds_to_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_global_step); + }); + + // move on E + cde_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_global_step); + } + }); + } + // clang-format on + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp new file mode 100755 index 000000000000..acbccf188973 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,1126 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +namespace ck { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct GridwiseGemmMultipleD_xdl_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + static_assert(!DoElementwiseBeforeCShuffle || NumDTensor == 0); + + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + +#if CK_GFX90A_DENORM_WORKAROUND + using AComputeDataType = + conditional_t, ck::bhalf_t, AComputeDataType_>; + using BComputeDataType = + conditional_t, ck::bhalf_t, BComputeDataType_>; +#else + using AComputeDataType = AComputeDataType_; + using BComputeDataType = BComputeDataType_; +#endif + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) + + 
b_block_space_size_aligned * sizeof(BComputeDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + // A desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // B desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + template + __host__ __device__ static bool + CheckTensorTransfersValidity(index_t MRaw, index_t NRaw, index_t KRaw) + { + // Check if the vector dim is K1 or M|N + const auto A_vector_dim_size = ABlockTransferSrcVectorDim == 2 ? KRaw : MRaw; + const auto B_vector_dim_size = BBlockTransferSrcVectorDim == 2 ? 
KRaw : NRaw; + const auto E_vector_dim_size = NRaw; + + // check vector load for A tensor + if constexpr(is_same_v) + { + if(!(A_vector_dim_size == KRaw && + A_vector_dim_size % ABlockTransferSrcScalarPerVector == 0)) + return false; + } + else if constexpr(is_same_v) + { + if(!(A_vector_dim_size == MRaw && + A_vector_dim_size % ABlockTransferSrcScalarPerVector == 0)) + return false; + } + else + { + return false; + } + + if constexpr(is_same_v) + { + if(!(B_vector_dim_size == NRaw && + B_vector_dim_size % BBlockTransferSrcScalarPerVector == 0)) + return false; + } + else if constexpr(is_same_v) + { + if(!(B_vector_dim_size == KRaw && + B_vector_dim_size % BBlockTransferSrcScalarPerVector == 0)) + return false; + } + else + { + return false; + } + + if constexpr(is_same_v) + { + if(!(E_vector_dim_size == NRaw && + E_vector_dim_size % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0)) + return false; + } + else if constexpr(is_same_v) + { + if(!(E_vector_dim_size == NRaw && + CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1)) + return false; + } + else + { + return false; + } + + return true; + } + + template + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, + const BGridDesc_N_K& b_grid_desc_n_k, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + [[maybe_unused]] const Block2ETileMap&, + index_t k_batch = 1) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "KPerBlock must be divisible by AK1Value and BK1Value!"); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto AK = a_grid_desc_m_k.GetLength(I1); + const auto BK = b_grid_desc_n_k.GetLength(I1); + + // check consistency of desc + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK)) + { + return false; + } + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0)) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = AK / (KPerBlock * k_batch); + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // check block-to-E-tile + // if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + //{ + // return false; + //} + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_n_k.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K, + index_t k_batch = 1) + { + const index_t num_loop = K / (KPerBlock * k_batch); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __host__ __device__ static auto + MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + constexpr 
auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const ck::Array& MRaws, + const ck::Array& NRaws, + const ck::Array& DsStride) +#else + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) +#endif + + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map, + const index_t k_batch = 1, + const index_t k_idx = 0) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + 
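The descriptor builders above differ only in the stride pattern handed to make_naive_tensor_descriptor (K contiguous for a row-major A, M contiguous for a column-major A) before the MatrixPadder rounds the raw extents up toward the block tile. A standalone sketch of that construction with the pad-to-tile step shown explicitly (the struct and helper names are placeholders, not CK types, and the real padder only pads the dimensions requested by the GEMM specialization):

// Row- vs column-major naive descriptor plus conceptual pad-to-tile.
#include <cstdio>

struct Desc2D { long len0, len1, stride0, stride1; };

constexpr long pad_up(long x, long tile) { return (x + tile - 1) / tile * tile; }

Desc2D make_a_desc_m_k(long M, long K, long StrideA, bool row_major,
                       long MPerBlock, long KPerBlock)
{
    Desc2D d = row_major ? Desc2D{M, K, StrideA, 1}   // K contiguous
                         : Desc2D{M, K, 1, StrideA};  // M contiguous
    d.len0 = pad_up(d.len0, MPerBlock);               // padded extents; the ragged edge
    d.len1 = pad_up(d.len1, KPerBlock);               // is handled by the padded descriptor
    return d;
}

int main()
{
    const Desc2D d = make_a_desc_m_k(/*M=*/100, /*K=*/70, /*StrideA=*/70,
                                     /*row_major=*/true, /*MPerBlock=*/128, /*KPerBlock=*/32);
    std::printf("padded A: %ld x %ld, strides (%ld, %ld)\n",
                d.len0, d.len1, d.stride0, d.stride1);
}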
auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t num_k_per_block = + __builtin_amdgcn_readfirstlane(a_grid_desc_ak0_m_ak1.GetLength(I0) / k_batch); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + AComputeDataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(num_k_per_block * k_idx, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BComputeDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(num_k_per_block * k_idx, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || + is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? 
true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + AComputeDataType, + BComputeDataType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + (KPerBlock * k_batch)); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + tensor_operation::element_wise::PassThrough pass_through{}; + const auto& vpgr_to_lds_element_op = [&] { + if constexpr(DoElementwiseBeforeCShuffle) + { + return cde_element_op; + } + else + { + return pass_through; + } + }; + const auto& lds_to_global_element_op = [&] { + if constexpr(!DoElementwiseBeforeCShuffle) + { + return cde_element_op; + } + else + { + return pass_through; + } + }; + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + 
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + conditional_t, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + vpgr_to_lds_element_op()}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + conditional_t, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + lds_to_global_element_op()}; + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // 
each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) + const ck::Array StrideDs, +#else + const std::array StrideDs, +#endif + const index_t StrideE, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + // tensor descriptors for problem definiton + const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K(M, K, StrideA); + const auto b_grid_desc_n_k = MakeBGridDescriptor_N_K(K, N, StrideB); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); + + const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AGridDesc_MK& a_grid_desc_m_k, + const BGridDesc_NK& b_grid_desc_n_k, + const DsGridDesc_MN& ds_grid_desc_m_n, + const EGridDesc_MN& e_grid_desc_m_n, + const 
Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_ak0_m_ak1 = MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k); + const auto b_grid_desc_bk0_n_bk1 = MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp new file mode 100755 index 000000000000..1e79d67f9380 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -0,0 +1,1035 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. 
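The header that starts here is the LDS-direct-load variant of the multiple-D GEMM: as the transfer headers it includes suggest, A and B tiles are moved from global memory straight into LDS rather than being staged through registers first, which is why its CheckValidity later insists that the A/B elementwise operations be PassThrough. A toy host-side contrast of the two copy styles (purely illustrative, not CK code):

// Staged copy: each element passes through a "register", so an elementwise op
// can be fused for free. Direct copy: bytes move untouched, PassThrough only.
#include <cstdio>
#include <cstring>

template <typename Op>
void staged_copy(const float* src, float* lds, int n, Op op)
{
    for(int i = 0; i < n; ++i)
    {
        float r = src[i];   // element lands in a register first
        lds[i] = op(r);     // elementwise op applied on the way through
    }
}

void direct_copy(const float* src, float* lds, int n)
{
    std::memcpy(lds, src, n * sizeof(float));  // no per-element hook available
}

int main()
{
    float src[4] = {1, 2, 3, 4}, lds[4];
    staged_copy(src, lds, 4, [](float x) { return x * 2.0f; });
    direct_copy(src, lds, 4);
    std::printf("%f\n", lds[3]);
}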
+ +#pragma once + +#include "ck/utility/amd_lds.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load( + const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation cde_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap block_2_etile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_ds_grid; + ignore = p_e_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = cde_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = e_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = block_2_etile_map; +#endif +} + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + +#if CK_GFX90A_DENORM_WORKAROUND + using AComputeDataType = + conditional_t, ck::bhalf_t, AComputeDataType_>; +#else + using AComputeDataType = AComputeDataType_; +#endif + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + if constexpr(is_same_v) + { + // FIXME: our support to non-K contiguous layout is limited, only work in some specific + // setting + return make_naive_tensor_descriptor_packed( + make_tuple(AK0PerBlock, Number{}, AK1)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(AK1, Number{}, I1)); + } + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + if constexpr(is_same_v) + { + // FIXME: our support to non-K contiguous layout is limited, only work in some specific + // setting + return make_naive_tensor_descriptor_packed( + make_tuple(BK0PerBlock, Number{}, BK1)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(BK1, Number{}, I1)); + } + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment. + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle. 
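The shuffle slab declared just below is sized from the wave decomposition of the block tile: MPerBlock splits into MWave waves of MXdlPerWave XDL tiles of MPerXdl rows each (likewise for N), and one shuffle pass stages CShuffleMXdlPerWavePerShuffle of those per wave; the A/B part of the budget a few lines further down is additionally scaled by NumGemmKPrefetchStage so each prefetch stage owns its own slab. A worked example with illustrative numbers (not this kernel's tuned values):

// Wave decomposition behind the C-shuffle LDS slab.
#include <cstdio>

int main()
{
    constexpr int MPerBlock = 128, NPerBlock = 128;
    constexpr int MPerXdl = 32, NPerXdl = 32;
    constexpr int MXdlPerWave = 2, NXdlPerWave = 2;
    constexpr int CShuffleMXdlPerWavePerShuffle = 1, CShuffleNXdlPerWavePerShuffle = 1;

    constexpr int MWave = MPerBlock / (MXdlPerWave * MPerXdl);  // waves along M
    constexpr int NWave = NPerBlock / (NXdlPerWave * NPerXdl);  // waves along N

    constexpr int shuffle_rows = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl;
    constexpr int shuffle_cols = CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl;

    std::printf("MWave=%d NWave=%d, shuffle slab = %d x %d elements\n",
                MWave, NWave, shuffle_rows, shuffle_cols);
}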
+ constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max( + NumGemmKPrefetchStage * a_block_space_size_aligned * sizeof(AComputeDataType) + + NumGemmKPrefetchStage * b_block_space_size_aligned * sizeof(BComputeDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + __host__ __device__ static auto + MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), + make_tuple(I1, StrideA)); + } + }(); + + return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw); + } + + __host__ __device__ static auto + MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), + make_tuple(StrideB, I1)); + } + }(); + + return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw); + } + + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); }, + Number{}); + } + + using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1)); + using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1)); + using DsGridDesc_M_N = remove_cvref_t; + using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1)); + + // A desc for source in blockwise copy. + __host__ __device__ static constexpr auto + MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // B desc for source in blockwise copy. 
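MakeDefaultAGridDescriptor_AK0_M_AK1 above, and its B counterpart right below, reshape the 2-D problem descriptors into the 3-D layout the blockwise copies consume by unmerging K into AK0 = K / AK1 outer slices of AK1 contiguous elements. A standalone model of that index split (the struct is a stand-in, not a CK type):

// Index arithmetic performed by the (M, K) -> (AK0, M, AK1) unmerge transform.
#include <cstdio>

struct Idx3 { long k0, m, k1; };

Idx3 to_ak0_m_ak1(long m, long k, long AK1)
{
    return {k / AK1, m, k % AK1};   // same split the unmerge transform applies
}

int main()
{
    constexpr long K = 256, AK1 = 8;
    constexpr long AK0 = K / AK1;   // 32 outer K slices
    const Idx3 i = to_ak0_m_ak1(/*m=*/5, /*k=*/100, AK1);
    std::printf("AK0=%ld, (m=5,k=100) -> (k0=%ld, m=%ld, k1=%ld)\n", AK0, i.k0, i.m, i.k1);
}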
+ __host__ __device__ static constexpr auto + MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // E desc for destination in blockwise copy. + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy. + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + using AGridDesc_AK0_M_AK1 = + remove_cvref_t; + using BGridDesc_BK0_N_BK1 = + remove_cvref_t; + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + using Block2ETileMap = remove_cvref_t; + + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, + const BGridDesc_N_K& b_grid_desc_n_k, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0, + "KPerBlock must be divisible by AK1Value and BK1Value!"); + + static_assert( + std::is_same_v && + std::is_same_v, + "Direct load transfers do not support elementwise operations other than passthrough."); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto AK = a_grid_desc_m_k.GetLength(I1); + const auto BK = b_grid_desc_n_k.GetLength(I1); + + // Check the consistency of descriptors. + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK)) + { + return false; + } + + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + // Check the tile size. + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0)) + { + return false; + } + + // Check gridwise gemm pipeline. 
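+        // The selected pipeline may constrain the number of K main-loop iterations
+        // (e.g. the 2-stage prefetch pipeline requires an even loop count), so the
+        // loop count derived from K / KPerBlock is validated against it here.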
+ const auto num_k_loop = AK / KPerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // Check block-to-E-tile. + if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // Check tensor size: cannot exceed 2GB. + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_n_k.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + // Elementwise operations are not supported for A and B, arguments left only for the API + // consistency. + (void)a_element_op; + (void)b_element_op; + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // Divide block work by [M, N]. + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + block_work_idx, + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // This forces m/n_block_data_idx_on_grid into SGPR. + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, destination of blockwise copy. + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, destination of blockwise copy. 
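+        // A and B tiles are brought into LDS with ThreadGroupTensorSliceTransfer_DirectLoad,
+        // i.e. global->LDS transfers issued directly, without a separate per-thread
+        // staging copy. This is also why CheckValidity() only accepts PassThrough
+        // element-wise ops for A and B: there is no per-element hook on this path.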
+ constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferSrcAccessOrder, + ADataType, + AComputeDataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferScalarPerVector>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferSrcAccessOrder, + BDataType, + BComputeDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferScalarPerVector>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || + is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? true + : false; + constexpr auto is_scale_mfma = false; + + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + AComputeDataType, + BComputeDataType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment. + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + const auto a_buffers_offset = 0; + auto a_block_buffers = + ck::lds_utils::AllocateLdsBuffers( + p_shared, + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), + a_buffers_offset, + max_lds_align); + const auto b_buffers_offset = a_block_space_size_aligned * NumGemmKPrefetchStage; + auto b_block_buffers = + ck::lds_utils::AllocateLdsBuffers( + p_shared, + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), + b_buffers_offset, + max_lds_align); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buffers, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buffers, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // Shuffle C and write out. 
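+        // Epilogue: the accumulators sit in VGPRs in the XDLOPS-native
+        // (M0, N0, M1, N1, M2, M3, M4, N2) layout. They are staged through LDS one
+        // CShuffle tile at a time so the final store can be issued as wide,
+        // N-contiguous vectors, with the D tensors read and fused through the CDE
+        // element-wise op on the way out to E.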
+ { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // Calculate the origin of thread output tensor on global memory. + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // Shuffle: threadwise copy C from VGPR to LDS. 
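+            // Each thread writes its accumulator fragment into the shared C tile at
+            // an offset derived from the (M0..M4, N0..N2) decomposition computed
+            // above; one CShuffleMXdlPerWavePerShuffle x CShuffleNXdlPerWavePerShuffle
+            // slice is handled per iteration of the space-filling curves below.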
+ auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // A tuple of reference to C/Ds tensor descriptors. + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // A tuple of reference to C/Ds grid buffers. + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // A tuple of starting index of C/Ds blockwise copy. + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + // Blockwise copy C/D/E between LDS and global. + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + cde_element_op}; + + // Space filling curve for threadwise C in VGPR before shuffle. + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // Space filling curve for shuffled blockwise C/D/E. + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // Make sure it's safe to write to LDS. + block_sync_lds(); + + // Each thread write its data from VGPR to LDS. + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // Make sure it's safe to read from LDS. 
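+                // The pair of barriers around each access enforces the usual
+                // write-then-read discipline on the shared C tile: the first lets the
+                // previous iteration's LDS->global copy finish before the tile is
+                // overwritten, and this one makes all VGPR->LDS writes visible before
+                // the block-wide copy to global memory starts.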
+ block_sync_lds(); + + // Each block copy its data from LDS to global. + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // Move on Ds. + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // Move on E. + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + struct Argument : public tensor_operation::device::BaseArgument + { + Argument(const void* p_a_grid, + const void* p_b_grid, + std::array p_ds_grid, + void* p_e_grid, + index_t MRaw, + index_t NRaw, + index_t KRaw, + index_t StrideA, + index_t StrideB, + std::array StrideDs, + index_t StrideE, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) + : p_a_grid_{static_cast(p_a_grid)}, + p_b_grid_{static_cast(p_b_grid)}, + p_ds_grid_{}, + p_e_grid_{static_cast(p_e_grid)}, + a_grid_desc_m_k_{MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)}, + b_grid_desc_n_k_{MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)}, + ds_grid_desc_m_n_{}, + e_grid_desc_m_n_{MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)}, + a_grid_desc_ak0_m_ak1_{MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)}, + b_grid_desc_bk0_n_bk1_{MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)}, + ds_grid_desc_mblock_mperblock_nblock_nperblock_{}, + e_grid_desc_mblock_mperblock_nblock_nperblock_{}, + block_2_etile_map_{MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + MRaw_{MRaw}, + NRaw_{NRaw}, + KRaw_{KRaw} + { + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + p_ds_grid_(i) = static_cast(p_ds_grid[i]); + ds_grid_desc_m_n_(i) = MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]); + }); + + if(CheckValidity(a_grid_desc_m_k_, + b_grid_desc_n_k_, + ds_grid_desc_m_n_, + e_grid_desc_m_n_, + block_2_etile_map_)) + { + ds_grid_desc_mblock_mperblock_nblock_nperblock_ = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_); + + e_grid_desc_mblock_mperblock_nblock_nperblock_ = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_); + } + } + + void Print() const + { + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; + std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; + static_for<0, NumDTensor, 1>{}( + [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); + std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + } + + // Pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + DsGridPointer p_ds_grid_; + EDataType* p_e_grid_; + + // Tensor descriptors for problem definiton + AGridDesc_M_K a_grid_desc_m_k_; + BGridDesc_N_K b_grid_desc_n_k_; + DsGridDesc_M_N ds_grid_desc_m_n_; + EGridDesc_M_N e_grid_desc_m_n_; + + // Tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + ds_grid_desc_mblock_mperblock_nblock_nperblock_; + EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + + // 
block-to-e-tile map + Block2ETileMap block_2_etile_map_; + + // element-wise ops + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation cde_element_op_; + + // For checking vector load/store + index_t MRaw_; + index_t NRaw_; + index_t KRaw_; + }; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp new file mode 100755 index 000000000000..5815eb5b0bc3 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp @@ -0,0 +1,1186 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct GridwiseGemmMultipleD_xdl_splitk_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, AK0PerBlock, Number{}, AK1), + make_tuple(AK0PerBlock * Number{} * AK1, + Number{} * AK1, + AK1, + I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, BK0PerBlock, Number{}, BK1), + make_tuple(BK0PerBlock * Number{} * BK1, + Number{} * BK1, + BK1, + I1)); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto 
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(ALDSType) + + b_block_space_size_aligned * sizeof(BLDSType), + c_block_size * sizeof(CShuffleDataType)); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch) + { + return math::integer_least_multiple(K, KPerBlock * K_Batch); + } + + template + __host__ __device__ static auto + MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto MPad = CalculateMPadded(M); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto AK0 = KPad / (KBatch * AK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto NPad = CalculateNPadded(N); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto BK0 = KPad / (KBatch * BK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + 
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + template + __host__ __device__ static constexpr bool + CheckValidity(const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch) + { + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + ignore = StrideDs; + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + +#if 0 + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + return false; + } +#endif + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t 
K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const index_t KBatch, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation_& cde_element_op, + const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1, + const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_kbatch_ak0_m_ak1 = + GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_kbatch_bk0_n_bk1 = + GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + 
ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ALDSType, + decltype(a_grid_desc_kbatch_ak0_m_ak1), + decltype(a_block_desc_kbatch_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_kbatch_ak0_m_ak1, + make_multi_index(kbatch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_kbatch_ak0_m_ak1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BLDSType, + decltype(b_grid_desc_kbatch_bk0_n_bk1), + decltype(b_block_desc_kbatch_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_kbatch_bk0_n_bk1, + make_multi_index(kbatch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_kbatch_bk0_n_bk1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? 
true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ALDSType, + BLDSType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched, + AComputeType, + BComputeType>(); + + if constexpr(Zeroing) + { + if(block_work_idx[I0] == 0) + { + const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock; + const index_t numNThreads = NPerBlock / nThreadSize; + const index_t numMThreads = BlockSize / numNThreads; + const index_t mThreadSize = MPerBlock / numMThreads; + + const index_t m_tid = get_thread_local_1d_id() / numNThreads; + const index_t n_tid = get_thread_local_1d_id() % numNThreads; + + auto c_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + StaticBuffer + e_thread_zero_buf; + + auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< + EDataType, + EDataType, + decltype(c_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, mThreadSize, 1, nThreadSize>, + Sequence<0, 1, 2, 3>, + 3, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], + m_tid * mThreadSize, + block_work_idx[I2], + n_tid * nThreadSize), + ck::tensor_operation::element_wise::PassThrough{}}; + + c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + e_thread_zero_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + __builtin_amdgcn_s_barrier(); + + if(threadIdx.x == 0) + { + atomicAdd(barrier_count_finished, 1); + } + } + } + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = + __builtin_amdgcn_readfirstlane((a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_kbatch_ak0_m_ak1, + a_block_desc_kbatch_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_kbatch_bk0_n_bk1, + b_block_desc_kbatch_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + if constexpr(Zeroing) + { + if(threadIdx.x == 0) + { + while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} + 
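+                    // Split-K ordering: the workgroup handling k-batch 0 zero-fills
+                    // its E tile and bumps barrier_count_finished; the remaining
+                    // k-batches for that tile spin here on thread 0 until the
+                    // zero-fill is published, then re-converge at the s_barrier below
+                    // before accumulating their partial results into E.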
} + __builtin_amdgcn_s_barrier(); + } + + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, 
+ 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0); + }, + Number{})); + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation_, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)), + cde_element_op}; + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( 
+ c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor_, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + + if constexpr(Zeroing) + { + if(threadIdx.x == 0) + { + index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); + + if(k_id_finished_t == KBatch) + { + *barrier_count_finished = 0; + } + } + } + } + } + + template + __device__ static void RunWithZeroing(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + if(kbatch_id == KBatch - 1) + { + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + else + { + Run, true>( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + ck::tensor_operation::element_wise::PassThrough{}, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + 
ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + uint32_t*, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + nullptr, + KBatch, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp new file mode 100755 index 000000000000..f8de0a48e5ab --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +#include +#include +#endif + +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp" + +namespace ck { + +enum struct PipelineVersion +{ + v1, + v2, + // v3 is only used in the Stream-K implementation. 
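+    // v4 maps to the direct global->LDS load pipeline
+    // (gridwise_gemm_pipeline_v4_direct_load.hpp); weight_only maps to
+    // GridwiseGemmPipeline_v1_WeightOnly in the selector below.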
+ v4, + weight_only, +}; + +template +constexpr auto GridwiseGemmPipeline_Selector() +{ + if constexpr(PipelineVer == PipelineVersion::v1) + { + if constexpr(LoopSched == LoopScheduler::Default) + { + return GridwiseGemmPipeline_v1{}; + } + else if constexpr(LoopSched == LoopScheduler::Interwave) + { + return GridwiseGemmPipelineInterwave_v1{}; + } + } + else if constexpr(PipelineVer == PipelineVersion::v2) + { + return GridwiseGemmPipeline_v2{}; + } + else if constexpr(PipelineVer == PipelineVersion::v4) + { + return GridwiseGemmPipeline_v4{}; + } + else if constexpr(PipelineVer == PipelineVersion::weight_only) + { + return GridwiseGemmPipeline_v1_WeightOnly{}; + } + else + { +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) + std::cerr << "GridwiseGemmPipeline configuration is not available" << std::endl; +#endif + } +} + +} // namespace ck + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +inline std::ostream& operator<<(std::ostream& os, const ck::PipelineVersion& p) +{ + switch(p) + { + case ck::PipelineVersion::v1: os << "PipelineVersion::v1"; break; + case ck::PipelineVersion::v2: os << "PipelineVersion::v2"; break; + case ck::PipelineVersion::v4: os << "PipelineVersion::v4"; break; + case ck::PipelineVersion::weight_only: os << "PipelineVersion::weight_only"; break; + default: os << ""; + } + return os; +} +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp new file mode 100755 index 000000000000..0cdb7ce2ca01 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp @@ -0,0 +1,770 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
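+// This header provides the GridwiseGemmPipeline_v1 prefetch schedules used by the
+// selector above: specializations for 1- and 2-stage global->LDS prefetch, plus
+// variants that differ in how the A tile is staged on its way to the block buffer.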
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +struct GridwiseGemmPipeline_v1; + +// 1-stage prefetch +template <> +struct GridwiseGemmPipeline_v1<1, true, true> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +// 2-stage prefetch +template <> +struct GridwiseGemmPipeline_v1<2, true, true> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + { + // TODO: improve applicability + return num_loop % 2 == 0; + } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return (num_loop / 2) > 1; + } + + template + static __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // preload data into LDS + { + // Read 0 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, 
a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Read 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Write i + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Read i+2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm i + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Move + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Write i+1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); + + // Read i+3 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1); + + // Sync + block_sync_lds(); + + // Gemm i+1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + { + // Write num_loop - 2 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0); + + // Sync + block_sync_lds(); + + // Gemm num_loop - 2 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // Sync + block_sync_lds(); + + // Write num_loop - 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I1); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I1); + + // Sync + block_sync_lds(); + + // Gemm num_loop - 1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +template <> +struct GridwiseGemmPipeline_v1<1, false, true> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + auto a_block_buf_switch = a_block_buf; + + // preload data into LDS + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + b_blockwise_copy.RunWrite(b_block_desc, 
b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_block_buf = a_block_buf_switch; + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + } + } +}; + +template <> +struct GridwiseGemmPipeline_v1<1, true, false> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + auto b_block_buf_switch = b_block_buf; + + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch); + + block_sync_lds(); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + b_block_buf = b_block_buf_switch; + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + } + } +}; + +template <> +struct GridwiseGemmPipeline_v1<1, false, false> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + 
ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + auto b_block_buf_switch = b_block_buf; + auto a_block_buf_switch = a_block_buf; + + // preload data into LDS + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.Run( + a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch); + b_blockwise_copy.Run( + b_grid_desc, b_grid_buf, b_block_desc, b_block_origin_idx, b_block_buf_switch); + + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_block_buf = a_block_buf_switch; + b_block_buf = b_block_buf_switch; + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + } + } +}; + +template +struct GridwiseGemmPipeline_v1_WeightOnly; + +template <> +struct GridwiseGemmPipeline_v1_WeightOnly<1, true, true> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const ScaleGridDesc& scale_grid_desc, + const ScaleGridBuffer& scale_grid_buf, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // Global Prefetch Stage 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + // Scale read once + b_blockwise_copy.RunScaleRead(scale_grid_desc, scale_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + // Dequantization fused in blockwise_copy + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + 
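// Same shape as the standard 1-stage pipeline: prefetch the next A/B tiles from
+ // global memory while blockwise_gemm consumes the tiles already in LDS; the
+ // scale tensor was read once before the loop, and the dequantization is fused
+ // into the B blockwise copy (see the comments above). +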
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +template +struct GridwiseGemmPipelineInterwave_v1; + +template <> +struct GridwiseGemmPipelineInterwave_v1<1> +{ + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + static __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // preload data into LDS + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + // block_sync_lds(); // moved into blockwise_gemm + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +// Note: 2 stage prefetch not optimized for inter-wave loop scheduler +template <> +struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2, true, true> +{ +}; + +// TODO: deprecate as GridwiseGemmPipeline_Selector covers the functionality +template +constexpr auto GridwiseGemmPipeline_v1_Selector() +{ + if constexpr(LoopSched == LoopScheduler::Default) + { + return GridwiseGemmPipeline_v1{}; + } + else if constexpr(LoopSched == LoopScheduler::Interwave) + { + return GridwiseGemmPipelineInterwave_v1{}; + } +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp new file mode 100755 index 000000000000..25e1cebdbe06 --- /dev/null +++ 
b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" + +namespace ck { + +struct GridwiseGemmPipeline_v2 +{ + __host__ __device__ static constexpr bool IsSupported(const index_t num_loop) + { + // TODO: improve applicability + return num_loop % 2 == 0; + } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(const index_t num_loop) + { + return (num_loop / 2) > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // global read 0 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + // move to 1 + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // LDS write 0 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + // global Read 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + // LDS write 0 + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + // global Read 1 + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { +#if CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT + __builtin_amdgcn_iglp_opt(CK_EXPERIMENTAL_PIPELINE_V2_IGLP_OPT); +#endif + + block_sync_lds(); + + // GEMM i + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + // move to i + 2 + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // LDS write i + 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + // global read i + 2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + + // LDS write i + 1 + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + // global read i + 2 + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + ++i; + } while(i < (num_loop - 2)); + } + + // tail + { + block_sync_lds(); + + // GEMM num_loop - 2 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + // LDS write num_loop - 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + block_sync_lds(); + + // GEMM num_loop - 1 + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp new file mode 100755 index 000000000000..ced62241cd29 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
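+// GridwiseGemmPipeline_v3: a single-stage pipeline driven by a plain while-loop. It is
+// intentionally absent from PipelineVersion and GridwiseGemmPipeline_Selector; as noted
+// there, v3 is only used by the Stream-K implementation.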
+ +#pragma once + +#include "ck/utility/common_header.hpp" + +namespace ck { + +struct GridwiseGemmPipeline_v3 +{ + __host__ __device__ static constexpr bool IsSupported(index_t) + { + // TODO: improve applicability + return true; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // global read 0 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // LDS write 0 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + num_loop--; + + while(num_loop > 0) + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + block_sync_lds(); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + num_loop--; + } + // tail + { + block_sync_lds(); + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp new file mode 100755 index 000000000000..08d986d0da62 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v4_direct_load.hpp @@ -0,0 +1,238 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/loop_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" + +namespace lds_direct_load { + +__device__ void sched_barrier() +{ +#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM + // When direct loads and `waitcnt` instructions are submitted using inline asm, the usage of + // `sched_barrier` is necessary to make sure no instructions that use the loaded memory + // are scheduled by the compiler before the `waitcnt` instruction. 
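+    // A zero mask asks the compiler not to schedule any instruction across this point.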
+ __builtin_amdgcn_sched_barrier(0); +#endif +} + +} // namespace lds_direct_load + +namespace ck { + +template +struct GridwiseGemmPipeline_v4; + +// 1-stage prefetch +template <> +struct GridwiseGemmPipeline_v4<1> +{ + static constexpr auto I0 = Number<0>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffers& a_block_bufs, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffers& b_block_bufs, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + static_assert(ABlockBuffers::Size() == 1 && BBlockBuffers::Size() == 1); + auto& a_block_buf = a_block_bufs.At(I0); + auto& b_block_buf = b_block_bufs.At(I0); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +// 2-stages prefetch +template <> +struct GridwiseGemmPipeline_v4<2> +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + { + return num_loop % 2 == 0; + } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return (num_loop / 2) > 1; + } + + template + __device__ static void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffers& a_block_bufs, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffers& b_block_bufs, + const BBlockTransferStep& b_block_copy_step, + const BlockwiseGemm& blockwise_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + static_assert(ABlockBuffers::Size() == 2 && BBlockBuffers::Size() == 2); + auto& a_block_buf1 = a_block_bufs.At(I0); + auto& a_block_buf2 = a_block_bufs.At(I1); + auto& b_block_buf1 = b_block_bufs.At(I0); + auto& b_block_buf2 = b_block_bufs.At(I1); + + 
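// Prefetch the first K tile into buffer set 1; the main loop below then ping-pongs,
+ // issuing direct loads into one buffer set while blockwise_gemm consumes the other. +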
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf1); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf1); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf2); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf2); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + blockwise_gemm.Run(a_block_buf1, b_block_buf1, c_thread_buf); + + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf1); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf1); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + blockwise_gemm.Run(a_block_buf2, b_block_buf2, c_thread_buf); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + { + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf2); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf2); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + blockwise_gemm.Run(a_block_buf1, b_block_buf1, c_thread_buf); + + block_sync_lds_direct_load(); + lds_direct_load::sched_barrier(); + + blockwise_gemm.Run(a_block_buf2, b_block_buf2, c_thread_buf); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp new file mode 100755 index 000000000000..db227bb7ef6e --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp @@ -0,0 +1,894 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
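+// GEMM + reduction kernel: the C tile is written out through the XDL c-shuffle
+// epilogue and, in the same pass, reduced along N (one value per row of M) into each
+// pointer of ReducePtrsGlobal, applying the supplied element-wise and reduction
+// operations.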
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_reduce_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + ReducePtrsGlobal p_reduces_grid, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const ReduceInElementwiseOperations reduce_in_element_ops, + const ReduceAccElementwiseOperations reduce_out_element_ops, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_reduces_grid, + p_shared, + a_element_op, + b_element_op, + c_element_op, + reduce_in_element_ops, + reduce_out_element_ops, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + reduce_grid_desc_mblock_mperblock, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_reduces_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = reduce_in_element_ops; + ignore = reduce_out_element_ops; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = reduce_grid_desc_mblock_mperblock; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + 
decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + // static_assert(is_known_at_compile_time>::value && + // is_known_at_compile_time>::value, + // "wrong! 
K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr auto + MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M& d_grid_desc_m) + { + const auto M = d_grid_desc_m.GetLength(I0); + const auto MBlock = M / MPerBlock; + + const auto reduce_grid_desc_mblock_mperblock = transform_tensor_descriptor( + d_grid_desc_m, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{}))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + + return reduce_grid_desc_mblock_mperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + using ReduceGridDescriptor_MBlock_MPerBlock = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + ReducePtrsGlobal p_reduces_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const ReduceInElementwiseOperations& reduce_in_element_ops, + const ReduceAccElementwiseOperations& reduce_out_element_ops, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const ReduceGridDescriptor_MBlock_MPerBlock& reduce_grid_desc_mblock_mperblock, + const Block2CTileMap& 
block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? 
true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( + lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C + reduction + write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = 
ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + // TODO: this should be implemented as a blockwise reduction + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) * + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + BlockSize, + "wrong!"); + + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) == + 0 && + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1); + + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; + + // VGPR c_reduce_thread_desc_mperblock_nperblock + constexpr auto c_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + // VGPR reduce_thread_desc_mperblock + constexpr auto reduce_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + // VGPR reduce_thread_desc_mblock_mperblock + constexpr auto reduce_thread_desc_mblock_mperblock = + 
make_naive_tensor_descriptor_packed(make_tuple(I1, Number{})); + + auto c_reduce_thread_buf = make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + // reduce: threadwise copy from LDS to VGPR + constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( + CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = + c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; + + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatCShuffle, + FloatReduceAcc, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; + + auto reduce_tuple_thread_copy_vgpr_to_global = generate_tuple( + [&](auto I) { + auto p_reduce_grid = p_reduces_grid[I]; + auto reduce_acc_element_op = reduce_out_element_ops[I]; + + return ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + remove_pointer_t, + decltype(reduce_thread_desc_mblock_mperblock), + decltype(reduce_grid_desc_mblock_mperblock), + decltype(reduce_acc_element_op), + Sequence<1, mreduce_per_thread>, + Sequence<0, 1>, + 1, + CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, + ReduceGlobalMemoryDataOperation::At(I), + 1, + false>{reduce_grid_desc_mblock_mperblock, + make_multi_index(block_work_idx[I0], // mblock + c_reduce_thread_data_idx_begin[I0]), // mperblock + reduce_acc_element_op}; + }, + Number{}); + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // TODO - extract following into reduction_blockwise + { + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf); + + static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) { + auto& p_reduce_grid = p_reduces_grid[In]; + + auto reduce_grid_buf = make_dynamic_buffer( + p_reduce_grid, reduce_grid_desc_mblock_mperblock.GetElementSpaceSize()); + + auto reduce_thread_buf = + make_static_buffer( + reduce_thread_desc_mperblock.GetElementSpaceSize()); + + auto& reduce_in_element_op = reduce_in_element_ops[In]; + + auto& reduce_thread_copy_vgpr_to_global = + reduce_tuple_thread_copy_vgpr_to_global(In); + + using ReduceOperation = remove_cvref_t; + using ThreadwiseReduce = + ThreadwiseReduction; + + // Global write Gemm shuffle + reduction + const auto 
reduce_identityVal = + ReduceOperation::template GetIdentityValue(); + + static_for<0, mreduce_per_thread, 1>{}( + [&](auto I) { reduce_thread_buf(I) = reduce_identityVal; }); + + // reduce in VGPR + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto offset = + Number{}; + + reduce_in_element_op(c_reduce_thread_buf(offset), + c_reduce_thread_buf(offset)); + }); + }); + + ThreadwiseReduce::Reduce(c_reduce_thread_buf, reduce_thread_buf); + + // copy from VGPR to Global + reduce_thread_copy_vgpr_to_global.Run(reduce_thread_desc_mblock_mperblock, + make_tuple(I0, I0), + reduce_thread_buf, + reduce_grid_desc_mblock_mperblock, + reduce_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + reduce_thread_copy_vgpr_to_global.MoveDstSliceWindow( + reduce_grid_desc_mblock_mperblock, + make_tuple(c_global_step[I0], c_global_step[I1])); + } + }); + } + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + + // Reduction + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp new file mode 100755 index 000000000000..70301c326a86 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle.hpp @@ -0,0 +1,1295 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... 
and E have the same layout +template +struct GridwiseGemmSplitKMultipleD_xdl_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = GridwiseGemmPipeline_v1; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, src of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, AK0PerBlock, Number{}, AK1), + make_tuple(AK0PerBlock * Number{} * AK1, + Number{} * AK1, + AK1, + I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, src of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, BK0PerBlock, Number{}, BK1), + make_tuple(BK0PerBlock * Number{} * BK1, + Number{} * BK1, + BK1, + I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto 
c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(ABDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + // A desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultAGridDescriptor_AKB_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, + const int split_k) + { + const auto MRaw = a_grid_desc_m_k.GetLength(I0); + const auto KRaw = a_grid_desc_m_k.GetLength(I1); + + const index_t AK0 = + (math::integer_divide_ceil(KRaw, KPerBlock * split_k) * KPerBlock) / AK1; + const index_t K = split_k * AK0 * AK1; + const auto KPad = K - KRaw; + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(split_k, AK0, AK1)), + make_pass_through_transform(MRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + + // B desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultBGridDescriptor_BKB_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, + const int split_k) + { + const auto NRaw = b_grid_desc_n_k.GetLength(I0); + const auto KRaw = b_grid_desc_n_k.GetLength(I1); + + const index_t BK0 = + (math::integer_divide_ceil(KRaw, KPerBlock * split_k) * KPerBlock) / BK1; + const index_t K = split_k * BK0 * BK1; + const auto KPad = K - KRaw; + + const auto b_grid_desc_n_kpad = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + b_grid_desc_n_kpad, + make_tuple(make_unmerge_transform(make_tuple(split_k, BK0, BK1)), + make_pass_through_transform(NRaw)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const EGridDescriptor_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDescriptor_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n, const int split_k) + { + 
return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + e_grid_desc_m_n, 8, split_k); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AKB_AK0_M_AK1& a_grid_desc_akb_ak0_m_ak1, + const BGridDesc_BKB_BK0_N_BK1& b_grid_desc_bkb_bk0_n_bk1, + const DsGridDesc_M_N& ds_grid_desc_m_n, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& block_2_etile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_akb_ak0_m_ak1.GetLength(I2); + const auto N = b_grid_desc_bkb_bk0_n_bk1.GetLength(I2); + const auto K = + a_grid_desc_akb_ak0_m_ak1.GetLength(I1) * a_grid_desc_akb_ak0_m_ak1.GetLength(I3); + + if(K != b_grid_desc_bkb_bk0_n_bk1.GetLength(I1) * b_grid_desc_bkb_bk0_n_bk1.GetLength(I3)) + { + return false; + } + if(a_grid_desc_akb_ak0_m_ak1.GetLength(I0) != b_grid_desc_bkb_bk0_n_bk1.GetLength(I0)) + { + return false; + } + + // check consistency of desc + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1))) + { + return false; + } + + bool valid = true; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) && + N == ds_grid_desc_m_n[i].GetLength(I1)); + }); + + if(!valid) + { + return false; + } + + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // check block-to-E-tile + if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_akb_ak0_m_ak1.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && + b_grid_desc_bkb_bk0_n_bk1.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DefaultAGridDesc_AK0_M_AK1 = + remove_cvref_t; + using DefaultBGridDesc_BK0_N_BK1 = + remove_cvref_t; + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + using DefaultBlock2ETileMap = + remove_cvref_t; + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __device__ static void Run(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AGridDesc_AKB_AK0_M_AK1& a_grid_desc_akb_ak0_m_ak1, + const BGridDesc_BKB_BK0_N_BK1& b_grid_desc_bkb_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + 
e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(block_work_idx[Number<0>{}] == 0) + { + Run0(p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_akb_ak0_m_ak1, + b_grid_desc_bkb_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + else + { + Run1(p_a_grid, + p_b_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + a_grid_desc_akb_ak0_m_ak1, + b_grid_desc_bkb_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + } + template + __device__ static void Run0(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const AGridDesc_AKB_AK0_M_AK1& a_grid_desc_akb_ak0_m_ak1, + const BGridDesc_BKB_BK0_N_BK1& b_grid_desc_bkb_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_akb_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bkb_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t k_batch_id = block_work_idx[I0]; + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_akb_ak0_m_ak1 = + GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bkb_bk0_n_bk1 = + GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + 
ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(a_grid_desc_akb_ak0_m_ak1), + decltype(a_block_desc_akb_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_akb_ak0_m_ak1, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_akb_ak0_m_ak1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(b_grid_desc_bkb_bk0_n_bk1), + decltype(b_block_desc_bkb_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bkb_bk0_n_bk1, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bkb_bk0_n_bk1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? 
true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ABDataType, + ABDataType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_akb_ak0_m_ak1.GetLength(I1) * a_grid_desc_akb_ak0_m_ak1.GetLength(I3)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_akb_ak0_m_ak1, + a_block_desc_akb_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bkb_bk0_n_bk1, + b_block_desc_bkb_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
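+            // Note on the C-tile decomposition used by this shuffle stage: the lengths
+            // queried below satisfy
+            //   MPerBlock = MXdlPerWave * MWave * MPerXdl  (i.e. M0 * M1 * (M2 * M3 * M4))
+            //   NPerBlock = NXdlPerWave * NWave * NPerXdl  (i.e. N0 * N1 * N2)
+            // where the M2/M3/M4 (and N2) split reflects the MFMA output layout reported
+            // by blockwise_gemm. Purely illustrative numbers, not tied to any config in
+            // this PR: MPerBlock = 128, MXdlPerWave = 2, MPerXdl = 32 gives
+            // MWave = 128 / (2 * 32) = 2, so M0 = 2, M1 = 2 and M2 * M3 * M4 = 32.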
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + { + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = 
concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0); + }, + Number{})); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t, + uniform_sequence_gen_t< + NumDTensor, + false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)), + cde_element_op}; + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + 
cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + } + + template + __device__ static void Run1(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const AGridDesc_AKB_AK0_M_AK1& a_grid_desc_akb_ak0_m_ak1, + const BGridDesc_BKB_BK0_N_BK1& b_grid_desc_bkb_bk0_n_bk1, + const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_akb_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bkb_bk0_n_bk1.GetElementSpaceSize()); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_etile_map.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t k_batch_id = block_work_idx[I0]; + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto a_block_desc_akb_ak0_m_ak1 = + GetABlockDescriptor_AKB_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + constexpr auto b_block_desc_bkb_bk0_n_bk1 = + GetBBlockDescriptor_BKB_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(a_grid_desc_akb_ak0_m_ak1), + decltype(a_block_desc_akb_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_akb_ak0_m_ak1, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_akb_ak0_m_ak1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(b_grid_desc_bkb_bk0_n_bk1), + decltype(b_block_desc_bkb_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + 
BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bkb_bk0_n_bk1, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bkb_bk0_n_bk1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ABDataType, + ABDataType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_v1_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_akb_ak0_m_ak1.GetLength(I1) * a_grid_desc_akb_ak0_m_ak1.GetLength(I3)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_akb_ak0_m_ak1, + a_block_desc_akb_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bkb_bk0_n_bk1, + b_block_desc_bkb_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
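+            // Note on the shuffle loop further below: the space-filling curves split the
+            // per-block C tile into
+            //   num_access = (MXdlPerWave / CShuffleMXdlPerWavePerShuffle)
+            //              * (NXdlPerWave / CShuffleNXdlPerWavePerShuffle)
+            // passes; each pass stages one CShuffle slab from VGPR to LDS, syncs, copies
+            // LDS to global, then advances the destination window by the curve's forward
+            // step. Illustrative numbers only (not from this PR): MXdlPerWave = 4 and
+            // NXdlPerWave = 4 with 2 xdl-per-wave per shuffle in each direction gives
+            // num_access = 2 * 2 = 4.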
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + { + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = 
ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + ck::tensor_operation::element_wise::PassThrough, // ElementwiseOperation, + EGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + EDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0), + ck::tensor_operation::element_wise::PassThrough{}}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp new file mode 100755 index 000000000000..f64838ea4e9f --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_split_k_multiple_d_xdl_cshuffle_v2.hpp @@ -0,0 +1,1090 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
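+//
+// Overview of this split-k variant (summarized from the code in this file):
+// * The block-to-E-tile map is indexed as (k_batch, m_block, n_block); each workgroup
+//   computes a partial GEMM over its K slice and writes E with
+//   EGlobalMemoryDataOperation (e.g. AtomicAdd when split-k accumulation is used).
+// * K is padded so every k-batch runs the same number of main-loop iterations:
+//   KPad = integer_least_multiple(K, KPerBlock * KBatch). Illustrative numbers, not
+//   taken from this PR: K = 1000, KPerBlock = 64, KBatch = 4 gives KPad = 1024.
+// * Workgroups with k_batch == 0 first zero-fill their E tile and then increment
+//   barrier_count_finished; every workgroup's epilogue spins on that counter before
+//   writing, so the accumulation always starts from an initialized tile.
+//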
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { + +// GEMM: +// input : A[M, K] +// input : B[N, K] +// input : D0[M, N], D1[M, N], ... +// output : E[M, N] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +// Assume: +// D0, D1, ... and E have the same layout +template +struct GridwiseGemmMultipleD_xdl_splitk_cshuffle +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, AK0PerBlock, Number{}, AK1), + make_tuple(AK0PerBlock * Number{} * AK1, + Number{} * AK1, + AK1, + I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(I1, BK0PerBlock, Number{}, BK1), + make_tuple(BK0PerBlock * Number{} * BK1, + Number{} * BK1, + BK1, + I1)); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + 
make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + // ck::Tuple + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType), + c_block_size * sizeof(CShuffleDataType)); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch) + { + return math::integer_least_multiple(K, KPerBlock * K_Batch); + } + + template + __host__ __device__ static auto + MakeAGridDescriptor_KBatch_AK0_M_AK1(index_t M, index_t K, index_t StrideA, index_t KBatch) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto MPad = CalculateMPadded(M); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto AK0 = KPad / (KBatch * AK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, AK0, AK1)), + make_pass_through_transform(M)), + 
make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + template + __host__ __device__ static auto + MakeBGridDescriptor_KBatch_BK0_N_BK1(index_t K, index_t N, index_t StrideB, index_t KBatch) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto NPad = CalculateNPadded(N); + const auto KPad = CalculateKPadded(K, KBatch); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto BK0 = KPad / (KBatch * BK1); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // Ds desc for source in blockwise copy + template + __host__ __device__ static constexpr auto + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n) + { + return generate_tuple( + [&](auto i) { + return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]); + }, + Number{}); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + e_grid_desc_m_n); + } + + template + __host__ __device__ static constexpr bool + CheckValidity(const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch) + { + const auto a_grid_desc_kbatch_ak0_m_ak1 = + 
MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + ignore = StrideDs; + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + +#if 0 + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + return false; + } +#endif + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + // check tensor size: cannot be larger than 2GB each + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + + template + __host__ __device__ static auto + MakeDsGridDescriptor_M_N(const std::array& MRaws, + const std::array& NRaws, + const std::array& DsStride) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + + return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); + }, + Number{}); + } + + __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; } + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + DsGridPointer p_ds_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const index_t KBatch, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation_& cde_element_op, + const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1, + const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1, + const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + ds_grid_desc_mblock_mperblock_nblock_nperblock, + const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], + ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize()); + }, + 
Number{}); + + auto e_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_kbatch_ak0_m_ak1 = + GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_kbatch_bk0_n_bk1 = + GetBBlockDescriptor_KBatch_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ComputeType, + decltype(a_grid_desc_kbatch_ak0_m_ak1), + decltype(a_block_desc_kbatch_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_kbatch_ak0_m_ak1, + make_multi_index(kbatch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_kbatch_ak0_m_ak1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + ComputeType, + decltype(b_grid_desc_kbatch_bk0_n_bk1), + decltype(b_block_desc_kbatch_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<2, 0, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_kbatch_bk0_n_bk1, + make_multi_index(kbatch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_kbatch_bk0_n_bk1, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? 
true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ComputeType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + +#if 1 + if(block_work_idx[I0] == 0) + { + const index_t nThreadSize = CDEShuffleBlockTransferScalarPerVector_NPerBlock; + const index_t numNThreads = NPerBlock / nThreadSize; + const index_t numMThreads = BlockSize / numNThreads; + const index_t mThreadSize = MPerBlock / numMThreads; + + const index_t m_tid = get_thread_local_1d_id() / numNThreads; + const index_t n_tid = get_thread_local_1d_id() % numNThreads; + + auto c_thread_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + StaticBuffer + e_thread_zero_buf; + + auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3< + EDataType, + EDataType, + decltype(c_thread_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + ck::tensor_operation::element_wise::PassThrough, + Sequence<1, mThreadSize, 1, nThreadSize>, + Sequence<0, 1, 2, 3>, + 3, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], + m_tid * mThreadSize, + block_work_idx[I2], + n_tid * nThreadSize), + ck::tensor_operation::element_wise::PassThrough{}}; + + c_thread_copy.Run(c_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + e_thread_zero_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_buf); + + __syncthreads(); + + if(threadIdx.x == 0) + { + atomicAdd(barrier_count_finished, 1); + } + } +#endif + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = + __builtin_amdgcn_readfirstlane((a_grid_desc_kbatch_ak0_m_ak1.GetLength(I1) * + a_grid_desc_kbatch_ak0_m_ak1.GetLength(I3)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_kbatch_ak0_m_ak1, + a_block_desc_kbatch_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_kbatch_bk0_n_bk1, + b_block_desc_kbatch_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + if(threadIdx.x == 0) + { + while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {} + } + + __syncthreads(); + + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + 
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + 
make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0); + }, + Number{})); + + // space filling curve for threadwise C in VGPR before shuffle + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + // blockwise copy C/D/E between LDS and global + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CDEElementwiseOperation_, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)), + cde_element_op}; + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + 
tie(e_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor_, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + + if(threadIdx.x == 0) + { + index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1); + + if(k_id_finished_t == KBatch) + { + *barrier_count_finished = 0; + } + } + } + } + + template + __device__ static void Run(const void* __restrict__ p_a_grid_, + const void* __restrict__ p_b_grid_, + DsGridPointer p_ds_grid, + void* __restrict__ p_e_grid_, + void* __restrict__ p_shared, + uint32_t* barrier_count_finished, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CDEElementwiseOperation& cde_element_op, + const index_t M, + const index_t N, + const index_t K, + const index_t StrideA, + const index_t StrideB, + const std::array StrideDs, + const index_t StrideE, + const index_t KBatch, + const Block2ETileMap& block_2_etile_map) + { + const auto p_a_grid = reinterpret_cast(p_a_grid_); + const auto p_b_grid = reinterpret_cast(p_b_grid_); + const auto p_e_grid = reinterpret_cast(p_e_grid_); + + using DsGridDesc_M_N = + remove_cvref_t({}, {}, {}))>; + + DsGridDesc_M_N ds_grid_desc_m_n; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + using DLayout = remove_cvref_t>; + + ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N(M, N, StrideDs[j]); + }); + + const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N(M, N, StrideE); + + // tensor descriptors for block/thread-wise copy + const auto a_grid_desc_kbatch_ak0_m_ak1 = + MakeAGridDescriptor_KBatch_AK0_M_AK1(M, K, StrideA, KBatch); + + const auto b_grid_desc_kbatch_bk0_n_bk1 = + MakeBGridDescriptor_KBatch_BK0_N_BK1(K, N, StrideB, KBatch); + + using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_mblock_mperblock_nblock_nperblock(j) = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[j]); + }); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); + + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + if(kbatch_id == KBatch - 1) + { + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + else + { + Run>( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + barrier_count_finished, + KBatch, + a_element_op, + b_element_op, + ck::tensor_operation::element_wise::PassThrough{}, + a_grid_desc_kbatch_ak0_m_ak1, + b_grid_desc_kbatch_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); + } + } +}; + +} // namespace ck diff --git 
a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp new file mode 100755 index 000000000000..de5a4241986f --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" + +namespace ck { + +template +struct GridwiseGemmLoadWave; + +// 1-stage prefetch +template +struct GridwiseGemmLoadWave +{ + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) + { + // TODO: improve applicability + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + static __device__ void RunLoadWavePipeline(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + index_t num_loop) + { + // global read 0 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + // move to 1 + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // LDS write 0 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + // sync for Load threads() + block_sync_lds(); + // global read i + 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + // move to i + 2 + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // sync with math threads() + block_sync_lds(); + + // LDS write i+1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + // GEMM num_loop - 1 + } + } +}; + +template +struct GridwiseGemmMathWave; +// 1- stage prefetch +template +struct GridwiseGemmMathWave +{ + + __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; } + + __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + { + return num_loop > 1; + } + + template + static __device__ void RunMathWavePipeline(ABlockBuffer& a_block_buf, + BBlockBuffer& b_block_buf, + const BlockwiseGemm& block_gemm, + CThreadBuffer& c_thread_buf, + index_t num_loop) + { + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + + do + { + block_sync_lds(); + + // GEMM i + block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + ++i; + } while(i < (num_loop - 1)); + } + + // tail + { + block_sync_lds(); + + // GEMM num_loop - 1 + block_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + } +}; + +} // namespace ck diff --git 
a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp new file mode 100755 index 000000000000..4458b9356dd6 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -0,0 +1,1023 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_wmma(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + const AGridDesc a_grid_desc, + const BGridDesc b_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) + __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc, + b_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_grid_desc; + ignore = b_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx11__)) +} + +template +struct GridwiseGemm_Wmma +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // FIX ME: To be deprecated + static constexpr auto K1 = Number{}; + + static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma); + static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma); + static constexpr auto WmmaK = K1 == 16 ? 
32 : 16; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = + remove_cvref_t())>; + + // Describe how data store to (LDS/VGPR) buffer from Global memory + __host__ __device__ static constexpr auto MakeABlockDescriptor() + { + constexpr auto a_block_desc = [&]() { + if constexpr(AEnableLds) + { + // K0->M->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + } + else + { + constexpr auto A_KRow = I2; + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / A_KRow / K1; + // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); + } + }(); + + return a_block_desc; + } + + __host__ __device__ static constexpr auto MakeBBlockDescriptor() + { + constexpr auto b_block_desc = [&]() { + if constexpr(BEnableLds) + { + // K0->N->K1 Per Block + constexpr auto K0PerBlock = KPerBlock / K1; + constexpr auto max_lds_align = K1; + + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + } + else + { + + constexpr auto B_KRow = I2; + constexpr auto KWmmaPerblock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK / B_KRow / K1; + // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread + return make_naive_tensor_descriptor( + make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + I1, + K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + Number{} * K1, + K1, + K1, + K1, + I1)); + } + }(); + + return b_block_desc; + } + + __host__ __device__ static constexpr auto MakeABlockSliceCopyStep() + { + constexpr auto a_block_copy_step = [&]() { + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return a_block_copy_step; + } + + __host__ __device__ static constexpr auto MakeBBlockSliceCopyStep() + { + constexpr auto b_block_copy_step = [&]() { + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock / K1; + + return make_multi_index(K0PerBlock, 0, 0); + } + else + { + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + + return make_multi_index(KWmmaPerBlock, 0, 0, 0, 0, 0, 0); + } + }(); + + return b_block_copy_step; + } + + // Describe how data read from (LDS/VGPR) buffer + template + __host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&) + { + + constexpr auto a_wave_desc = [&]() { + if constexpr(AEnableLds) + { + // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1 + constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto A_KRow = I2; +#else + constexpr auto A_KRow = I1; +#endif + + return transform_tensor_descriptor( + ABlockDesc_{}, + 
make_tuple(make_unmerge_transform(make_tuple(Number{}, A_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = ABlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3); + constexpr auto A_KRow = ABlockDesc_{}.GetLength(I4); + constexpr auto A_K1 = ABlockDesc_{}.GetLength(I6); + + // Err: merge transform cause non-constexpr issue + + // return transform_tensor_descriptor( + // ABlockDesc_{}, + // make_tuple(make_merge_transform(make_tuple(Number{}, I1)), + // make_pass_through_transform(Number{}), + // make_pass_through_transform(I1), + // make_pass_through_transform(I1), + // make_pass_through_transform(Number{})), + // make_tuple(Sequence<0, 3>{}, + // Sequence<1>{}, + // Sequence<2>{}, + // Sequence<4>{}, + // Sequence<5>{}), + // make_tuple( + // Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, + // Sequence<4>{})); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return a_wave_desc; + } + + template + __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&) + { + constexpr auto b_wave_desc = [&]() { + if constexpr(BEnableLds) + { + // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1 + constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto B_KRow = I2; +#else + constexpr auto B_KRow = I1; +#endif + return transform_tensor_descriptor( + BBlockDesc_{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, B_KRow)), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + else + { + // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1 + constexpr auto KWmma = BBlockDesc_{}.GetLength(I0); + constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3); + constexpr auto B_KRow = BBlockDesc_{}.GetLength(I4); + constexpr auto B_K1 = BBlockDesc_{}.GetLength(I6); + + // Workaround, Freeze transform + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + I1, + Number{}, + I1, + Number{})); + } + }(); + + return b_wave_desc; + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool CheckValidity(const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! 
K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerWmma)) == 0, + "Invalid tuning param!"); + + const auto GetAProblemsizeMK = [&]() { + if constexpr(AEnableLds) + { + return make_tuple(a_grid_desc.GetLength(I1), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(a_grid_desc.GetLength(I1) * a_grid_desc.GetLength(I2) * + a_grid_desc.GetLength(I5), + a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) * + a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6)); + } + }; + + const auto GetBProblemsizeNK = [&]() { + if constexpr(BEnableLds) + { + return make_tuple(b_grid_desc.GetLength(I1), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2)); + } + else + { + return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) * + b_grid_desc.GetLength(I5), + b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) * + b_grid_desc.GetLength(I4) * b_grid_desc.GetLength(I6)); + } + }; + + const auto M = GetAProblemsizeMK()[I0]; + const auto N = GetBProblemsizeNK()[I0]; + const auto K = GetAProblemsizeMK()[I1]; + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K == GetBProblemsizeNK()[I1])) + { + printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", + GetAProblemsizeMK()[I0], + GetAProblemsizeMK()[I1], + GetBProblemsizeNK()[I0], + GetBProblemsizeNK()[I1], + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + printf("GridwiseOp err: ProblemSize check"); + return false; + } + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + printf("GridwiseOp err: ProblemSize division"); + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + printf("GridwiseOp err: Pipeline not support this k_loop"); + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + if(!(a_grid_desc.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB && + b_grid_desc.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB)) + { + return false; + } + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + struct 
SharedMemTrait + { + // LDS allocation for A and B: be careful of alignment + + static constexpr auto max_lds_align = K1; + + static constexpr auto a_block_space_size_aligned = + AEnableLds ? math::integer_least_multiple(MakeABlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + static constexpr auto b_block_space_size_aligned = + BEnableLds ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(), + max_lds_align) + : 0; + + static constexpr auto a_block_space_offset = 0; + static constexpr auto b_block_space_offset = a_block_space_size_aligned; + + // LDS allocation for C shuffle in LDS + static constexpr auto c_shuffle_block_space_size = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + .GetElementSpaceSize(); + + static constexpr auto c_shuffle_block_space_offset = 0; + + static constexpr auto lds_size = + math::max(c_shuffle_block_space_size * sizeof(CShuffleDataType), + a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)); + }; + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void Run(const ADataType* __restrict__ p_a_grid, + const BDataType* __restrict__ p_b_grid, + CDataType* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc& a_grid_desc, + const BGridDesc& b_grid_desc, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + // clang-format off +/*******************************************************************************/ +// Memory buffer zone. 
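+// The three make_dynamic_buffer calls below wrap the raw A/B/C grid pointers so the
+// block-/thread-wise transfer primitives can address global memory through the grid
+// tensor descriptors. The a_block_trait/b_block_trait lambdas further down choose
+// between an LDS-staged blockwise copy (ThreadGroupTensorSliceTransfer_v4r1) and a
+// direct-to-register threadwise copy (ThreadwiseTensorSliceTransfer_v2), depending on
+// AEnableLds/BEnableLds.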
+ const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + +/*******************************************************************************/ +// BlockIdx.x -> [BlockId.m, BlockId.n] + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { return; } + + // Store BlockId into SGPR + const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + +/*******************************************************************************/ +// BlockLevel, A/B Matrix ThreadMapping in WMMA Source buffer, As Destinaion of BlockWise_Copy + const auto K = [&](){ + if constexpr(AEnableLds){ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I2); + } + else{ + return a_grid_desc.GetLength(I0) * a_grid_desc.GetLength(I3) + * a_grid_desc.GetLength(I4) * a_grid_desc.GetLength(I6); + } + }(); + + constexpr auto a_block_desc = MakeABlockDescriptor(); + constexpr auto b_block_desc = MakeBBlockDescriptor(); + + auto a_block_trait = [&](){ + // A matrix blockwise copy + if constexpr(AEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + SharedMemTrait::a_block_space_size_aligned); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, +/* typename ThreadClusterLengths, */ ABlockTransferThreadClusterLengths_K0_M_K1, +/* typename ThreadClusterArrangeOrder, */ ABlockTransferThreadClusterArrangeOrder, +/* typename SrcData, */ ADataType, +/* typename DstData, */ ADataType, +/* typename SrcDesc, */ decltype(a_grid_desc), +/* typename DstDesc, */ decltype(a_block_desc), +/* typename SrcDimAccessOrder, */ ABlockTransferSrcAccessOrder, +/* typename DstDimAccessOrder, */ Sequence<0, 1, 2>, +/* index_t SrcVectorDim, */ ABlockTransferSrcVectorDim, +/* index_t DstVectorDim, */ 2, +/* index_t SrcScalarPerVector, */ ABlockTransferSrcScalarPerVector, +/* index_t DstScalarPerVector, */ ABlockTransferDstScalarPerVector_K1, +/* index_t SrcScalarStrideInVector, */ 1, +/* index_t DstScalarStrideInVector, */ 1, +/* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun, +/* bool ThreadTransferDstResetCoordinateAfterRun, */ true, + NumGemmKPrefetchStage>( + a_grid_desc, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> MRepeat -> MWaves -> K0PerWmma -> KRow -> MPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto a_block_buf = make_static_buffer( + a_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto a_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + 
Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + ABlockTransferSrcScalarPerVector, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc, + make_multi_index(0, + m_block_data_idx_on_grid/(MWaves * MPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(a_block_buf, a_blockwise_copy); + } + }; + + auto b_block_trait = [&](){ + if constexpr(BEnableLds) + { + constexpr auto K0PerBlock = KPerBlock/ K1; + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::b_block_space_offset, + SharedMemTrait::b_block_space_size_aligned); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc), + decltype(b_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + else + { + // Thread-wise copy + // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1 + constexpr auto KWmmaPerBlock = KPerBlock / WmmaK; + constexpr auto K0PerWmma = WmmaK/2/K1Value; + auto b_block_buf = make_static_buffer( + b_block_desc.GetElementSpaceSize()); + + // Limitation: NumDim of Src and Dst descriptor should be identical + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + Number{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc, + make_multi_index(0, + n_block_data_idx_on_grid/(NWaves * NPerWmma), + get_thread_local_1d_id() / 32, + 0, + (get_thread_local_1d_id() % 32 )/ 16, + get_thread_local_1d_id() % 16, + 0)); + + return make_tuple(b_block_buf, b_blockwise_copy); + } + }; + + auto a_block_buf = a_block_trait()[I0]; + auto a_blockwise_copy = a_block_trait()[I1]; + + auto b_block_buf = b_block_trait()[I0]; + auto b_blockwise_copy = b_block_trait()[I1]; +/*******************************************************************************/ + // GEMM + constexpr auto KPack = math::integer_least_multiple(K1, WmmaK); + + auto blockwise_gemm = + BlockwiseGemmWMMA{}; + + // Prepare Register for C matrix + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + +/*******************************************************************************/ + // Shift Per SUB_K + constexpr auto a_block_slice_copy_step = MakeABlockSliceCopyStep(); + constexpr auto b_block_slice_copy_step = MakeBBlockSliceCopyStep(); + + // gridwise GEMM pipeline + const index_t KBlockMainLoop = __builtin_amdgcn_readfirstlane(K / KPerBlock); + GridwiseGemmPipe::template Run(a_grid_desc, + a_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc, + b_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + KBlockMainLoop); +/*******************************************************************************/ + // write out to C, implement shuffle + { + // C mapping in single thread. 
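+            // The epilogue below proceeds in two stages per space-filling-curve step:
+            // (1) each thread copies its accumulator tile from VGPRs into the shared
+            //     c_shuffle buffer (ThreadwiseTensorSliceTransfer_v1r3), and
+            // (2) the whole workgroup streams the shuffled tile from LDS out to global
+            //     memory (ThreadGroupTensorSliceTransfer_v6r1), with block_sync_lds()
+            //     separating the LDS write and read phases.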
+ constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm.GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // C mapping in single block + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + blockwise_gemm.GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I1); + constexpr auto MSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I2); + constexpr auto NWave = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I4); + constexpr auto NThreadPerSubGroup = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I5); + constexpr auto MAccVgprs = c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp.GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared) + SharedMemTrait::c_shuffle_block_space_offset, + SharedMemTrait::c_shuffle_block_space_size); + + constexpr auto c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0, 1, 2, 6>{}, Sequence<>{}, Sequence<3, 4, 5>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_idx = n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto 
c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 1, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + // clang-format on + } +}; + +} // namespace ck diff --git 
a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp new file mode 100755 index 000000000000..75f12d094e3c --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -0,0 +1,554 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/env.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp" + +namespace ck { + +/// @brief \"Universal\" GEMM kernel with SplitK support. +/// +/// @par Overview +/// This GEMM kernel is carrying out following mathematical equation: +/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N})) +/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are +/// elementwise operations that could be applied on each tensor respectively. +/// The \"universal\" gemm comes with multiple pipelines optimized for different usage +/// scenarios. That's why it's called \"universal\". It's universal through it's design +/// and versatilty. +/// +/// @note This Kernel implementation supports SplitK algorithm. It can be configured +/// to split the dot product accumulated over the K dimension into multiple working groups. +/// The partial products of different workgroups are then reduced using the AtomicAdd +/// operation. +/// +/// @tparam ALayout A tensor data layout. +/// @tparam BLayout B tensor data layout. +/// @tparam CLayout C tensor data layout. +/// @tparam ADataType A tensor data type. +/// @tparam BDataType B tensor data type. +/// @tparam AccDataType The accumulation data type related to the hardware +/// matrix-multiplication instruction. +/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into +/// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam CDataType C tensor data type. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor +/// (after GEMM). +/// @tparam GemmSpec Determines used "padding" version. +/// @tparam BlockSize The number of threads within workgroup. +/// @tparam MPerBlock The input/output data tile size in the M dimension. +/// @tparam NPerBlock The input/output data tile size in the N dimension. +/// @tparam KPerBlock The input data tile size in the K dimension. +/// @tparam AK1Value The vector load size from global memory for A tensor. +/// @tparam BK1Value The vector load size from global memory for B tensor. 
+/// @tparam MPerWmma M size of Wave Matrix Multiply Accumulate (WMMA) instruction.
+/// @tparam NPerWmma N size of Wave Matrix Multiply Accumulate (WMMA) instruction.
+/// @tparam MRepeat The number of iterations in the M dimension over output tile per wavefront.
+/// @tparam NRepeat The number of iterations in the N dimension over output tile per wavefront.
+/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input
+/// data. Can be interpreted as the answer
+/// to the question: "How many threads can be
+/// arranged on each input data axis?"
+/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
+/// the input tensor dimension. Can be interpreted
+/// as the answer to the question: "In which
+/// order to spread threads through tensor axes?".
+/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
+/// interpreted as the answer to the question "Which dimension
+/// to read first? And which next?" etc.
+/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
+/// access - the one with contiguous memory.
+/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of
+/// elements accessed per thread per instruction.
+/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory.
+/// @tparam AThreadTransferSrcResetCoordinateAfterRun Decides whether we reset the thread coordinate
+/// (return back to the window origin) after all threads finish the data copy.
+/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With
+/// universal GEMM there's no need for padding.
+/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input
+/// data. Can be interpreted as the answer
+/// to the question: "How many threads to
+/// arrange on each input data axis?"
+/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over
+/// the input tensor dimension. Can be interpreted
+/// as the answer to the question: "In which
+/// order to spread threads through tensor axes?".
+/// @tparam BBlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be
+/// interpreted as the answer to the question "Which dimension
+/// to read first? And which next?" etc.
+/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory
+/// access - the one with contiguous memory.
+/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of
+/// elements accessed per thread per instruction.
+/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory.
+/// @tparam BThreadTransferSrcResetCoordinateAfterRun Decides whether we reset the thread coordinate
+/// (return back to the window origin) after all threads finish the data copy.
+/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With universal GEMM
+/// there's no need for padding.
+/// @tparam CShuffleMRepeatPerShuffle The number of matrix-multiplication instruction
+/// results to process per wave per iteration of CShuffle
+/// in M dimension.
+/// @tparam CShuffleNRepeatPerShuffle The number of matrix-multiplication instruction
+/// results to process per wave per iteration of CShuffle
+/// in N dimension.
+/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// thread distribution used for storing data into output +/// tensor across output data layout dimensions. +/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. +/// Used when storing data to output tensor. +/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or +/// intrawave). +/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. +/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication +/// instructions. +/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication +/// instructions. +/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout +/// in global memory. Currently not supported! +/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout +/// in global memory (pre-shuffled). +template +struct GridwiseGemm_wmma_cshuffle_v3 + : GridwiseGemm_wmma_cshuffle_v3_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB> +{ + using Base = GridwiseGemm_wmma_cshuffle_v3_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Base::I0; + using Base::I1; + using 
Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + + using Base::APackedSize; + using Base::BPackedSize; + + using Base::CalculateAK0Padded; + using Base::CalculateBK0Padded; + using Base::CalculateKPadded; + using Base::CalculateKRead; + using Base::CalculateMBlock; + using Base::CalculateMPadded; + using Base::CalculateNBlock; + using Base::CalculateNPadded; + using Base::MakeAGridDescriptor_AK0_M_AK1; + using Base::MakeBGridDescriptor_BK0_N_BK1; + using Base::MakeCGridDescriptor_M_N; + + using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat; + + using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; + + using ThisThreadBlock = ThisThreadBlock; + + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t k_batch_, + bool is_reduce_ = false) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + is_reduce(is_reduce_) + { + } + + __host__ __device__ inline bool IsReduceAdd() const + { + return (Problem::KBatch > 1) && is_reduce; + } + + __host__ __device__ inline bool IsAtomicAdd() const + { + return (Problem::KBatch > 1) && (!is_reduce); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + bool is_reduce; + }; + + struct SplitKBatchOffset + { + + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead / APackedSize; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + if 
constexpr(!PermuteB) + { + b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; + } + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + + if(karg.IsReduceAdd()) + { + c_reduce_offset = blockIdx.z * karg.M * karg.N; + } + else + { + c_reduce_offset = 0; + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t c_reduce_offset; + }; + + using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe; + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + __device__ static index_t GetKBlockPerScale() { return 1; } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // BScale struct (Empty) + using BScale = typename BlockwiseGemmPipe::Empty; + auto b_scale_struct = BScale{}; + + const index_t num_k_block_per_scale = GetKBlockPerScale(); + + Base::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_m_id, + block_n_id, + num_k_block_per_scale, + b_scale_struct); + } + + // Wrapper function to have __global__ function in common + // between gemm_universal, b_scale, ab_scale, etc. 
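+    // Illustrative call sequence (taken from kernel_gemm_wmma_cshuffle_v3 in
+    // gridwise_gemm_wmma_cshuffle_v3_common.hpp, added later in this patch):
+    //   __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+    //   auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
+    //   GridwiseGemm::Run(p_shared, splitk_batch_offset, karg);
+    // The wrapper below only applies the split-K offsets to the A/B/C pointers and then
+    // forwards to the descriptor-building Run() defined above.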
+ template + __device__ static void + Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg) + { + Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp new file mode 100755 index 000000000000..7b6ad5ca3e15 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -0,0 +1,551 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/env.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp" + +namespace ck { + +template +struct GridwiseGemm_wmma_cshuffle_v3_b_scale + : GridwiseGemm_wmma_cshuffle_v3_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB> +{ + using BScaleType = ck::half_t; + + using Base = GridwiseGemm_wmma_cshuffle_v3_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + 
ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; + + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + + using Base::APackedSize; + using Base::BPackedSize; + + using Base::CalculateAK0Padded; + using Base::CalculateBK0Padded; + using Base::CalculateKPadded; + using Base::CalculateKRead; + using Base::CalculateMBlock; + using Base::CalculateMPadded; + using Base::CalculateNBlock; + using Base::CalculateNPadded; + using Base::MakeAGridDescriptor_AK0_M_AK1; + using Base::MakeBGridDescriptor_BK0_N_BK1; + using Base::MakeCGridDescriptor_M_N; + + using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat; + + using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; + + using ThisThreadBlock = ThisThreadBlock; + + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + StrideScaleB{StrideScaleB_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "SScaleB:" << StrideScaleB << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t StrideScaleB; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + const BScaleType* 
p_b_scale_grid_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_, + bool is_reduce_ = false) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, StrideScaleB_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_}, + is_reduce(is_reduce_) + { + } + + __host__ __device__ inline bool IsReduceAdd() const + { + return (Problem::KBatch > 1) && is_reduce; + } + + __host__ __device__ inline bool IsAtomicAdd() const + { + return (Problem::KBatch > 1) && (!is_reduce); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + + const BScaleType* p_b_scale_grid; + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + bool is_reduce; + }; + + struct SplitKBatchOffset + { + + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead / APackedSize; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + if constexpr(!PermuteB) + { + b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; + } + } + + // Calculate B scale offset + if constexpr(is_same_v) + { + scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideB; + } + else if constexpr(is_same_v) + { + scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK); + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + + if(karg.IsReduceAdd()) + { + c_reduce_offset = blockIdx.z * karg.M * karg.N; + } + else + { + c_reduce_offset = 0; + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t scale_k_split_offset; // New member for scale matrix offset + index_t c_reduce_offset; + }; + + using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe; + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const BScaleType* p_b_scale_grid, + index_t block_n_id) + { + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + static constexpr auto wmma = + WmmaSelector{}; + static constexpr auto KPerThread = wmma.selected_wmma.k_per_wmma; + + static constexpr auto ScaleSliceSizeN = NRepeat; + static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + auto b_thread_offset_n = get_thread_local_1d_id() % NPerWmma + + (get_thread_local_1d_id() / 32) % NWaves * NPerWmma; + auto b_thread_offset_k = (get_thread_local_1d_id() % 32) / NPerWmma * KPerThread; + + auto 
b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, + b_thread_offset_k / ScaleBlockK)); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + using BScale = + typename BlockwiseGemmPipe::template BScale; + + return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf}; + } + + __device__ static index_t GetKBlockPerScale() + { + return (ScaleBlockK + KPerBlock - 1) / KPerBlock; + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // B Scale grid + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(problem.StrideScaleB, 1)); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // BScale struct + auto b_scale_struct = MakeBScale<1>(b_scale_grid_desc_bn_ak, p_b_scale_grid, block_n_id); + + const index_t num_k_block_per_scale = GetKBlockPerScale(); + + Base::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_m_id, + block_n_id, + num_k_block_per_scale, + b_scale_struct); + } + + // NOTE: Wrapper function to have __global__ function in common + // between gemm_universal, b_scale, ab_scale, etc. 
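+    // The b_scale variant reuses the same shared __global__ entry point as the plain
+    // gemm; compared with that wrapper, the one below additionally advances
+    // p_b_scale_grid by splitk_batch_offset.scale_k_split_offset so that each split-K
+    // batch reads the per-block B scales matching its K slice.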
+ template + __device__ static void + Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg) + { + Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, + p_shared, + karg); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp new file mode 100755 index 000000000000..5a4a41e50703 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -0,0 +1,1420 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/env.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run( + p_shared, splitk_batch_offset, karg); + +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; +#endif +} + +template +struct GridwiseGemm_wmma_cshuffle_v3_base +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), + WmmaSelector::selected_wmma + .k_per_wmma); + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + 
return 2; + else + return 1; + }(); + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&) + { + // K0_MN_K1 -> K0_MNRepeat_MNWaves_KRow_MNPerWmma_K1 + constexpr auto K0 = BlockDesc{}.GetLength(I0); + constexpr auto K1 = BlockDesc{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto KRow = I2; +#else + constexpr auto KRow = I1; +#endif + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, KRow)), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + 
make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + static_assert(!PermuteA, "PermuteA is not supported"); + + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + 
else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // Pre-shuffled Weight + // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1] + constexpr index_t BK01 = KPerBlock / BK1Value; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } + } + } + + template + __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + + return MakeWmmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto MakeBWmmaTileDescriptor(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + return MakeWmmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + // TODO: Investigate why this path is not used in the original + // gridwise_gemm_xdl_cshuffle_v3.hpp +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + 
make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{} * AK1Number, AK1Number, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; + constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize; + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. 
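+        // The block below builds a permuted LDS descriptor for column-major A: the K0
+        // and M thread dimensions are regrouped (KThreadReadPerm, kfold, mpair) and run
+        // through an xor-with-modulo transform, spreading the blockwise-copy writes and
+        // the later WMMA-fragment reads across LDS banks (the same bank-conflict concern
+        // noted for the padded layout above).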
+ constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerWmma; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerWmma * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerWmma * sizeof(ADataType))) > M0 + ? M0 + : 128 / (AK1Number * MPerWmma * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{} * BK1Number, BK1Number, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; + constexpr index_t NLdsLayer = LdsSize < 1 ? 
1 : LdsSize; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerWmma; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerWmma * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerWmma * sizeof(BDataType))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerWmma * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ADataType, + BDataType, + ComputeTypeA, + ComputeTypeB, + AccDataType, + decltype(MakeAWmmaTileDescriptor(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), + decltype(MakeBWmmaTileDescriptor(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + KPack>())>; + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return 
c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NPerWmma * NRepeat)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) + { + if(!karg.IsReduceAdd()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + if(karg.KBatch > 1) + { + return false; + } + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize(); + + return 
math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + + b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), + c_block_size * sizeof(CShuffleDataType)); + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const index_t& block_m_id, + const index_t& block_n_id, + const index_t& num_k_block_per_scale, + BScaleStruct& b_scale_struct) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + 
static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * + sizeof(ADataType) / + APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + b_scale_struct, + num_k_block_main_loop, + num_k_block_per_scale); + + // shuffle C and write out + { + // C mapping in single thread. + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm_pipeline + .GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // C mapping in single block + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + blockwise_gemm_pipeline + .GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I1); + constexpr auto MSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I2); + constexpr auto NWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I4); + constexpr auto NThreadPerSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I5); + constexpr auto MAccVgprs = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize()); + + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 1, 2, 6>{}, + 
Sequence<>{}, + Sequence<3, 4, 5>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( + MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor + .CalculateBottomIndex(make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( + NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor + .CalculateBottomIndex(make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 1, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + 
Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp new file mode 100755 index 000000000000..63d40f6ff8e8 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -0,0 +1,1384 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
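+// Illustrative LDS budget (hypothetical 128x128x64 bf16 A/B tiles and an fp32
+// C-shuffle slab): the mainloop keeps the A and B tiles in LDS, and the C-shuffle
+// epilogue later reuses the same memory, so GetSharedMemoryNumberOfByte() below
+// takes the max of the two layouts rather than their sum. A compile-time sketch
+// of that arithmetic, with all sizes being assumed example values:
+#if 0
+constexpr long a_tile_bytes    = 128 * 64 * 2; // MPerBlock * KPerBlock * sizeof(bf16)
+constexpr long b_tile_bytes    = 128 * 64 * 2; // NPerBlock * KPerBlock * sizeof(bf16)
+constexpr long c_shuffle_bytes = 64 * 128 * 4; // one shuffle slab * sizeof(float)
+constexpr long lds_bytes       = (a_tile_bytes + b_tile_bytes > c_shuffle_bytes)
+                                     ? a_tile_bytes + b_tile_bytes
+                                     : c_shuffle_bytes; // 32 KiB for these sizes
+#endif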
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +struct GridwiseGemm_xdl_cshuffle_conv_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? true + : false; + static constexpr auto is_scale_mfma = false; + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = 
TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t k_batch_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 + ? 
1 + : 32 * 4 / KPerBlock / sizeof(ADataType); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 + ? 
1 + : 32 * 4 / KPerBlock / sizeof(BDataType); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + 
c_block_size * sizeof(CShuffleDataType)); + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const index_t k_id = 0) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex( + make_multi_index(static_cast(blockIdx.x))); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + 
decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k_id, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k_id, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + (KPerBlock * problem.KBatch)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
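+            // Sketch of how many LDS round trips the write-out below makes (per-wave
+            // tiling values are hypothetical): each access stages
+            // CShuffleMXdlPerWavePerShuffle x CShuffleNXdlPerWavePerShuffle XDL tiles
+            // through LDS, with block_sync_lds() fencing the VGPR->LDS and
+            // LDS->global halves of every access.
+#if 0
+            constexpr int MXdlPerWave_ex = 4, NXdlPerWave_ex = 4; // hypothetical
+            constexpr int CShuffleM_ex   = 2, CShuffleN_ex   = 2; // hypothetical
+            constexpr int num_access_ex  = (MXdlPerWave_ex / CShuffleM_ex) *
+                                           (NXdlPerWave_ex / CShuffleN_ex); // 4 round trips
+#endif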
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto 
c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const index_t k_id = 0) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const 
CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex( + make_multi_index(static_cast(blockIdx.x))); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k_id, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k_id, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * 
sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + (KPerBlock * problem.KBatch)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto 
c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + 
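+                // (In this 2-LDS variant the shuffle slab lives in p_shared_0 only:
+                //  once the mainloop has drained, the bytes that held the ping-buffer
+                //  A/B tiles are reinterpreted as CShuffleDataType, which is why the
+                //  shared-memory size is the max of the mainloop and epilogue layouts.)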
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp new file mode 100755 index 000000000000..d45ed79ae312 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp @@ -0,0 +1,2735 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/utility/workgroup_barrier.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg, karg.p_workspace_); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.p_workspace_); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemm_xdl_cshuffle_streamk_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? 
true + : false; + static constexpr auto is_scale_mfma = false; + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + // Pad both M and K to be multiples of the block sizes + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + 
make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } +#endif + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + // Pad both N and K to be multiples of the block sizes + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = 
transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } +#endif + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // Pad both M and N to be multiples of the block sizes + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == 
GemmSpecialization::MPadding ||
+                         GemmSpec == GemmSpecialization::MKPadding)
+        {
+            // pad M, but not N
+            return transform_tensor_descriptor(
+                c_grid_desc_mraw_nraw,
+                make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+                          GemmSpec == GemmSpecialization::NKPadding)
+        {
+            // pad N, but not M
+            return transform_tensor_descriptor(
+                c_grid_desc_mraw_nraw,
+                make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
+                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}));
+        }
+        else
+        {
+            // not pad M or N
+            return c_grid_desc_mraw_nraw;
+        }
+#endif
+    }
+
+    struct Problem
+    {
+        __host__ Problem(index_t M_,
+                         index_t N_,
+                         index_t K_,
+                         index_t StrideA_,
+                         index_t StrideB_,
+                         index_t StrideC_,
+                         index_t Streamk_sel_,
+                         index_t Grid_size_,
+                         StreamKReductionStrategy reduction_strategy_)
+            : M{M_},
+              N{N_},
+              K{K_},
+              StrideA{StrideA_},
+              StrideB{StrideB_},
+              StrideC{StrideC_},
+              Streamk_sel{Streamk_sel_},
+              Grid_size{Grid_size_},
+              reduction_strategy{reduction_strategy_}, // Initialize the member variable
+              MPadded{CalculateMPadded(M_)},
+              NPadded{CalculateNPadded(N_)},
+              KRead{CalculateKRead(K_, 1)},
+              KPadded{CalculateKPadded(K_, 1)},
+              AK0{CalculateAK0Padded(K_, 1)},
+              BK0{CalculateBK0Padded(K_, 1)},
+              MBlock{CalculateMBlock(M_)},
+              NBlock{CalculateNBlock(N_)}
+
+        {
+        }
+
+        __host__ void Print() const
+        {
+            std::cout << "problem {"
+                      << "M:" << M << ", "
+                      << "N:" << N << ", "
+                      << "K:" << K << ", "
+                      << "SA:" << StrideA << ", "
+                      << "SB:" << StrideB << ", "
+                      << "SC:" << StrideC << ", "
+                      << "MP:" << MPadded << ", "
+                      << "NP:" << NPadded << ", "
+                      << "KRead:" << KRead << ", "
+                      << "KP:" << KPadded << ", "
+                      << "AK0:" << AK0 << ", "
+                      << "BK0:" << BK0 << ", "
+                      << "MBlock: " << MBlock << ", "
+                      << "NBlock: " << NBlock << ", "
+                      << "Stream-K Selection:" << Streamk_sel << ", "
+                      << "Grid size:" << Grid_size << ", "
+                      << "Reduction Strategy:"
+                      << (reduction_strategy == StreamKReductionStrategy::Atomic ?
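
// [Reviewer sketch - not part of this diff] A worked example of the derived fields the
// Problem constructor fills in, assuming MPerBlock = NPerBlock = 128, KPerBlock = 64 and
// AK1 = BK1 = 8 (hypothetical tile settings; the real values come from template parameters):
//
//   M = 1000, N = 768, K = 4100
//   MPadded = 1024 (next multiple of 128)    NPadded = 768 (already aligned)
//   KPadded = 4160 (next multiple of 64)
//   AK0 = 4160 / 8 = 520                     BK0 = 4160 / 8 = 520
//   MBlock = 8, NBlock = 6                   -> 48 output tiles
//
// Stream-K then spreads MBlock * NBlock * (KPadded / KPerBlock) = 48 * 65 K-iterations
// over Grid_size workgroups instead of launching exactly one workgroup per output tile.
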
"Atomic" + : "Reduction") + << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t Streamk_sel; + mutable index_t Grid_size; + StreamKReductionStrategy reduction_strategy; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t Streamk_sel_, + index_t Grid_size_, + StreamKReductionStrategy reduction_strategy_) + : Problem{M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + Streamk_sel_, + Grid_size_, + reduction_strategy_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + block_2_ctile_map_streamk(M_, + N_, + AK0Number * CalculateKPadded(K_, 1), + Grid_size_, + Streamk_sel_, + reduction_strategy_) + + { + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + BlockToCTileMap_GemmStreamK_v2 + block_2_ctile_map_streamk; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Problem& problem, unsigned int kbatch_id, unsigned int orig_K) + { + if constexpr(is_same_v) + { + a_k_split_offset = kbatch_id * problem.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = kbatch_id * problem.KRead * problem.M; + } + + if constexpr(is_same_v) + { + b_k_split_offset = kbatch_id * problem.KRead * problem.N; + } + else if constexpr(is_same_v) + { + b_k_split_offset = kbatch_id * problem.KRead; + } + + if(kbatch_id < static_cast(problem.KBatch - 1)) + { + problem.K = problem.KRead; + } + else + { + problem.K = orig_K - problem.KRead * (problem.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 + ? 
1 + : 32 * 4 / KPerBlock / sizeof(ADataType); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? 
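
// [Reviewer sketch - not part of this diff] The xor_with_modulo transform in the LDS
// descriptors above is a bank-conflict swizzle: successive rows of the tile store their
// K-vectors at XOR-rotated column positions, so lanes that later read the same logical
// column hit different LDS banks. A scalar model of the idea only; the actual CK transform
// also folds in the vector width, the MLdsLayer/NLdsLayer factor and the kfold/pair factors:
#include <cstdint>

inline std::uint32_t swizzled_col(std::uint32_t row, std::uint32_t col, std::uint32_t cols)
{
    return (col ^ row) % cols; // column 0 of rows 0..cols-1 lands in cols distinct columns
}
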
M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 + ? 
1 + : 32 * 4 / KPerBlock / sizeof(BDataType); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})); + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto 
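
// [Reviewer sketch - not part of this diff] GetSharedMemoryNumberOfByte above sizes LDS as
// the larger of (A tile + B tile) used by the main loop and the C-shuffle staging tile used
// by the epilogue, since the two phases reuse the same allocation. A plain-number version
// with hypothetical tile sizes (bf16 A/B and fp32 CShuffle staging assumed):
#include <cstddef>

constexpr std::size_t align_up(std::size_t x, std::size_t a) { return (x + a - 1) / a * a; }

constexpr std::size_t lds_bytes(std::size_t a_elems, std::size_t b_elems, std::size_t c_elems,
                                std::size_t ab_elem_bytes, std::size_t c_elem_bytes,
                                std::size_t align_elems)
{
    const std::size_t ab = align_up(a_elems, align_elems) * ab_elem_bytes +
                           align_up(b_elems, align_elems) * ab_elem_bytes;
    const std::size_t c  = c_elems * c_elem_bytes;
    return ab > c ? ab : c; // both phases fit in max(main-loop bytes, epilogue bytes)
}

static_assert(lds_bytes(128 * 64, 128 * 64, 32 * 128, 2, 4, 8) == 32768,
              "128x64 bf16 A and B tiles dominate a 32x128 fp32 C-shuffle tile");
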
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + + if(karg.K <= 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! 
" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr auto GetClusterLengthReduction() + { + // TODO: assume C is row major + // TODO: we always first loop over N, then M + constexpr auto NPerBlockPow2 = math::next_power_of_two(); + constexpr auto NPerBlockReduction = + NPerBlockPow2 / CShuffleBlockTransferScalarPerVector_NPerBlock; + constexpr auto MPerBlockReduction = + (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction; + return Sequence{}; + } + + __host__ __device__ static constexpr auto GetPartialAccBlockDescriptor() + { + const auto c_partial_acc_block_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock), + make_tuple(NPerBlock, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock), + make_tuple(I1, MPerBlock)); + } + }(); + return c_partial_acc_block_m_n; + } + using Block2CTileMap_streamk = BlockToCTileMap_GemmStreamK_v2; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + Problem& problem, + void* p_workspace) + { + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M, + problem.N, + AK0Number * problem.KPadded, + problem.Grid_size, + problem.Streamk_sel, + problem.reduction_strategy); + uint32_t iter_start, iter_end; + bool is_sk_block, is_dp_block, is_reduction_block; + index_t num_k_block_main_loop; + const auto 
c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + uint32_t* p_semaphore = reinterpret_cast( + reinterpret_cast(p_workspace) + + block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType))); + + for(auto block_idx = get_block_1d_id(); + block_idx < block_2_ctile_map_streamk.get_grid_dims(); + block_idx += gridDim.x) + { + + is_sk_block = + static_cast(block_idx) < block_2_ctile_map_streamk.sk_num_blocks; + is_dp_block = + static_cast(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx && + static_cast(block_idx) < + block_2_ctile_map_streamk.reduction_start_block_idx; + + block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); + num_k_block_main_loop = iter_end - iter_start; + + if(problem.reduction_strategy == StreamKReductionStrategy::Reduction) + { + is_reduction_block = static_cast(block_idx) >= + block_2_ctile_map_streamk.reduction_start_block_idx; + if(is_reduction_block) + { + // descriptors + constexpr auto cluster_length_reduce = GetClusterLengthReduction(); + constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce); + const auto reduce_thread_cluster_idx = + reduce_desc.CalculateBottomIndex(make_multi_index(block_idx)); + const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0]; + const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1]; + + constexpr auto MReduceIters = math::integer_divide_ceil( + Number{}, cluster_length_reduce.At(I0)); + constexpr auto NReduceIters = math::integer_divide_ceil( + Number{}, + cluster_length_reduce.At(I1) * + Number{}); + + constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{})); + constexpr auto acc_thread_buf_store_desc = + make_naive_tensor_descriptor_packed(make_tuple( + I1, I1, I1, Number{})); + + constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor(); + + constexpr auto partial_acc_load_step_n = + make_multi_index(0, + cluster_length_reduce.At(I1) * + CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_load_step_n_reverse = make_multi_index( + 0, + -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * + CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_load_step_m = + make_multi_index(cluster_length_reduce.At(I0), 0); + + constexpr auto partial_acc_store_step_n = + make_multi_index(0, + 0, + 0, + cluster_length_reduce.At(I1) * + CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_store_step_n_reverse = make_multi_index( + 0, + 0, + 0, + -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * + CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_store_step_m = + make_multi_index(0, cluster_length_reduce.At(I0), 0, 0); + + StaticBuffer + parcial_acc_buf; + StaticBuffer + acc_buf; + + // start to compute + auto reduction_idx = + block_idx - block_2_ctile_map_streamk.reduction_start_block_idx; + auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial( + reduction_idx, problem.M, problem.N); + + workgroup_barrier wg_barrier(p_semaphore); + + uint32_t tile_acc_offset_start = + 
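
// [Reviewer sketch - not part of this diff] The loop above classifies each workgroup by its
// flat block index: below sk_num_blocks it is a stream-k block, between dp_start_block_idx
// and reduction_start_block_idx it is a data-parallel block, and above that (Reduction
// strategy only) it is a reduction block. get_block_itr then hands every compute block a
// half-open [iter_start, iter_end) range of K-iterations. A toy version of such a split,
// assuming the simplest possible policy of dividing the stream-k iterations evenly (the
// real BlockToCTileMap_GemmStreamK_v2 policy is more elaborate):
#include <cstdint>

struct IterRange { std::uint32_t begin, end; };

inline IterRange sk_block_iters(std::uint32_t sk_block_idx, std::uint32_t sk_blocks,
                                std::uint32_t sk_iters)
{
    const std::uint32_t base  = sk_iters / sk_blocks;
    const std::uint32_t extra = sk_iters % sk_blocks; // first `extra` blocks take one more iter
    const std::uint32_t begin =
        sk_block_idx * base + (sk_block_idx < extra ? sk_block_idx : extra);
    const std::uint32_t len = base + (sk_block_idx < extra ? 1u : 0u);
    return {begin, begin + len};
}
// A block whose range crosses an output-tile boundary produces a partial tile, which is
// exactly what the Atomic / Reduction strategies below have to combine.
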
block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx); + uint32_t tile_acc_offset_end = + block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx + + 1); + __syncthreads(); + + auto acc_load = ThreadwiseTensorSliceTransfer_v2< + AccDataType, // SrcData, + AccDataType, // DstData, + decltype(c_partial_acc_block_m_n), // SrcDesc, + decltype(acc_thread_buf_load_desc), // DstDesc, + Sequence<1, + CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths, + Sequence<0, 1>, // DimAccessOrder, + 1, // SrcVectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector, + 1, // SrcScalarStrideInVector, + false // SrcResetCoordinateAfterRun, + >{c_partial_acc_block_m_n, + make_multi_index(thread_m_cluster_id, + thread_n_cluster_id * + CShuffleBlockTransferScalarPerVector_NPerBlock)}; + + auto acc_store = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, // SrcData, + CDataType, // DstData, + decltype(acc_thread_buf_store_desc), // SrcDesc, + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc, + CElementwiseOperation, // ElementwiseOperation, + Sequence<1, + 1, + 1, + CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths, + Sequence<0, 1, 2, 3>, // DimAccessOrder, + 3, // DstVectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector, + InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp, + 1, // DstScalarStrideInVector, + false // DstResetCoordinateAfterRun, + >{c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]), + thread_m_cluster_id, + __builtin_amdgcn_readfirstlane(spatial_idx[I1]), + thread_n_cluster_id * + CShuffleBlockTransferScalarPerVector_NPerBlock), + CElementwiseOperation{}}; + + wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start); + + if(threadIdx.x == 0) + { + p_semaphore[reduction_idx] = 0; + } + using Accumulation = ck::detail:: + AccumulateWithNanCheck; + + for(int i_m = 0; i_m < MReduceIters; i_m++) + { + static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) { + acc_buf.Clear(); + for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++) + { + auto c_partial_acc_buf = + make_dynamic_buffer( + reinterpret_cast(p_workspace) + + i * c_partial_acc_block_m_n.GetElementSpaceSize(), + c_partial_acc_block_m_n.GetElementSpaceSize()); + + acc_load.Run(c_partial_acc_block_m_n, + c_partial_acc_buf, + acc_thread_buf_load_desc, + make_tuple(I0, I0), + parcial_acc_buf); + + static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}( + [&](auto i_vec) { + constexpr auto offset = + acc_thread_buf_load_desc.CalculateOffset( + make_tuple(0, i_vec)); + Accumulation::Calculate(acc_buf(Number{}), + parcial_acc_buf[Number{}]); + }); + } + + if(thread_n_cluster_id * + CShuffleBlockTransferScalarPerVector_NPerBlock < + NPerBlock) + { + acc_store.Run(acc_thread_buf_store_desc, + make_tuple(I0, I0, I0, I0), + acc_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + if constexpr(NReduceIters != 1) + { + if constexpr(i_n_reduce != (NReduceIters - 1)) + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_n); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_n); + } + else + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_n_reverse); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + 
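
// [Reviewer sketch - not part of this diff] The Reduction path here is a per-tile counter
// protocol: each stream-k block that flushed a partial tile bumps a semaphore slot, and the
// reduction block waits until the counter equals the number of partials before summing the
// workspace buffers and issuing the single store to C. A host-side model with std::atomic
// (conceptual only; the kernel uses workgroup_barrier over uint32_t slots in p_workspace):
#include <atomic>
#include <cstddef>
#include <cstdint>
#include <vector>

struct TileSemaphore
{
    std::atomic<std::uint32_t> count{0};

    void signal_partial_written() { count.fetch_add(1, std::memory_order_release); }

    void wait_for(std::uint32_t expected) // reduction side
    {
        while (count.load(std::memory_order_acquire) != expected) { /* spin */ }
    }
};

// The reduction itself: the same element is accumulated across `num_partials` back-to-back
// MPerBlock x NPerBlock buffers stored consecutively in the workspace.
inline float reduce_partials(const std::vector<float>& workspace, std::size_t tile_elems,
                             std::uint32_t num_partials, std::size_t elem)
{
    float acc = 0.f;
    for (std::uint32_t p = 0; p < num_partials; ++p)
        acc += workspace[p * tile_elems + elem];
    return acc;
}
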
partial_acc_store_step_n_reverse); + } + } + }); + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_m); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_m); + } + } + + continue; + } + } + + // offset for last acc buffer of this block + uint32_t block_acc_offset = + (block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * + MPerBlock * NPerBlock; + while(true) + { + uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( + block_2_ctile_map_streamk.get_current_iter_length( + iter_start, iter_end, num_k_block_main_loop)); + uint32_t tile_idx, iter_offset; + block_2_ctile_map_streamk.get_tile_idx_with_offset( + iter_end - 1, tile_idx, iter_offset); + iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); + + auto block_work_idx = + block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + const index_t k0_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(iter_offset * AK0Number); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + BElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + 
make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = + make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = + make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle = + GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock + .GetElementSpaceSize()); + + auto c_partial_acc_buf = + make_dynamic_buffer( + reinterpret_cast(p_workspace) + block_acc_offset, + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + 
ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + // CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * + NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + // LDS to global partial acc + auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + // InMemoryDataOperationEnum::Set, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * + NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + AccDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be + // false, othre wise has scratch + false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be + // false, othre wise has scratch + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + make_multi_index(0, 0, 0, 0), + c_element_op}; + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = SpaceFillingCurve< + Sequence<1, MPerBlock, 1, NPerBlock>, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + 
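
// [Reviewer sketch - not part of this diff] num_access above is simply how many shuffle
// passes are needed to drain the accumulators through the LDS staging tile: each wave owns
// MXdlPerWave x NXdlPerWave XDL tiles and every pass moves CShuffleMXdlPerWavePerShuffle x
// CShuffleNXdlPerWavePerShuffle of them (the numbers in the check below are hypothetical):
constexpr int c_shuffle_passes(int m_xdl_per_wave, int n_xdl_per_wave,
                               int m_per_shuffle, int n_per_shuffle)
{
    return (m_xdl_per_wave / m_per_shuffle) * (n_xdl_per_wave / n_per_shuffle);
}
static_assert(c_shuffle_passes(4, 2, 1, 1) == 8, "4x2 XDL tiles drained one per pass");
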
+ static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + c_shuffle_block_copy_lds_to_global.SetSrcSliceOrigin( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple(0, 0, 0, 0)); + + if(is_dp_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + else if(is_sk_block) + { + if(problem.reduction_strategy == StreamKReductionStrategy::Atomic) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + else if(problem.reduction_strategy == + StreamKReductionStrategy::Reduction) + { + // constexpr offset + c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple(0, 0, 0, 0)); + + c_block_copy_lds_to_partial_acc.SetDstSliceOrigin( + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + make_tuple(MXdlPerWave, 0, NXdlPerWave, 0)); + + c_block_copy_lds_to_partial_acc + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + c_partial_acc_buf); + } + } + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + + if(problem.reduction_strategy == StreamKReductionStrategy::Reduction) + { + if(is_sk_block) + { + // increase the counter for this tile + workgroup_barrier wg_barrier(p_semaphore); + wg_barrier.inc(tile_idx); + } + } + } // shuffle c and write-out end + + // exit condition + iter_end -= current_iter_length; + if(iter_end <= iter_start) + break; + if(problem.reduction_strategy == StreamKReductionStrategy::Reduction) + { + block_acc_offset -= MPerBlock * NPerBlock; + } + // make sure next loop LDS is ready for use + block_sync_lds(); + } // while loop + + } // for loop + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + Problem& problem, + void* p_workspace) + { + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = 
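
// [Reviewer sketch - not part of this diff] The branch above picks one of three ways to move
// a finished LDS tile out to memory; a condensed decision table (illustrative enum/function
// only, and it assumes the caller is a compute block, never a reduction block):
enum class WritePath { DirectStore, AtomicAccumulate, PartialToWorkspace };

inline WritePath choose_write_path(bool is_dp_block, bool is_sk_block, bool use_reduction)
{
    if (is_dp_block) return WritePath::DirectStore;  // block owns the whole output tile
    if (is_sk_block && !use_reduction)
        return WritePath::AtomicAccumulate;          // partial tile atomically added into C
    return WritePath::PartialToWorkspace;            // partial tile stashed for a reduction block
}
// The first two paths reuse the same LDS-to-global copy with a different in-memory data
// operation; the third targets the AccDataType workspace and then bumps the tile semaphore.
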
make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + uint32_t iter_start, iter_end; + bool is_sk_block, is_dp_block, is_reduction_block; + index_t num_k_block_main_loop; + + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + Block2CTileMap_streamk block_2_ctile_map_streamk(problem.M, + problem.N, + AK0Number * problem.KPadded, + problem.Grid_size, + problem.Streamk_sel, + problem.reduction_strategy); + for(auto block_idx = get_block_1d_id(); + block_idx < block_2_ctile_map_streamk.get_grid_dims(); + block_idx += gridDim.x) + { + is_sk_block = + static_cast(block_idx) < block_2_ctile_map_streamk.sk_num_blocks; + is_dp_block = + static_cast(block_idx) >= block_2_ctile_map_streamk.dp_start_block_idx && + static_cast(block_idx) < + block_2_ctile_map_streamk.reduction_start_block_idx; + + block_2_ctile_map_streamk.get_block_itr(block_idx, iter_start, iter_end); + num_k_block_main_loop = iter_end - iter_start; + + uint32_t* p_semaphore = reinterpret_cast( + reinterpret_cast(p_workspace) + + block_2_ctile_map_streamk.get_workspace_size_for_acc(sizeof(AccDataType))); + + if(problem.reduction_strategy == StreamKReductionStrategy::Reduction) + { + is_reduction_block = static_cast(block_idx) >= + block_2_ctile_map_streamk.reduction_start_block_idx; + if(is_reduction_block) + { + // descriptors + constexpr auto cluster_length_reduce = GetClusterLengthReduction(); + constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce); + const auto reduce_thread_cluster_idx = + reduce_desc.CalculateBottomIndex(make_multi_index(block_idx)); + const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0]; + const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1]; + + constexpr auto MReduceIters = math::integer_divide_ceil( + Number{}, cluster_length_reduce.At(I0)); + constexpr auto NReduceIters = math::integer_divide_ceil( + Number{}, + cluster_length_reduce.At(I1) * + Number{}); + + constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{})); + constexpr auto acc_thread_buf_store_desc = + make_naive_tensor_descriptor_packed(make_tuple( + I1, I1, I1, Number{})); + + constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor(); + + constexpr auto partial_acc_load_step_n = + make_multi_index(0, + cluster_length_reduce.At(I1) * + CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_load_step_n_reverse = make_multi_index( + 0, + -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * + CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_load_step_m = + make_multi_index(cluster_length_reduce.At(I0), 0); + + constexpr auto partial_acc_store_step_n = + make_multi_index(0, + 0, + 0, + cluster_length_reduce.At(I1) * + CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_store_step_n_reverse = make_multi_index( + 0, + 0, + 0, + -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * + CShuffleBlockTransferScalarPerVector_NPerBlock); + constexpr auto partial_acc_store_step_m = + make_multi_index(0, cluster_length_reduce.At(I0), 0, 0); + + StaticBuffer + 
parcial_acc_buf; + StaticBuffer + acc_buf; + + // start to compute + auto reduction_idx = + block_idx - block_2_ctile_map_streamk.reduction_start_block_idx; + auto spatial_idx = block_2_ctile_map_streamk.tile_to_spatial( + reduction_idx, problem.M, problem.N); + + workgroup_barrier wg_barrier(p_semaphore); + + uint32_t tile_acc_offset_start = + block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx); + uint32_t tile_acc_offset_end = + block_2_ctile_map_streamk.get_acc_buffer_offset_from_tile(reduction_idx + + 1); + + uint32_t expected_count = tile_acc_offset_end - tile_acc_offset_start; + + if(threadIdx.x == 0) + { + p_semaphore[reduction_idx] = 0; + } + + __syncthreads(); + + auto acc_load = ThreadwiseTensorSliceTransfer_v2< + AccDataType, // SrcData, + AccDataType, // DstData, + decltype(c_partial_acc_block_m_n), // SrcDesc, + decltype(acc_thread_buf_load_desc), // DstDesc, + Sequence<1, + CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths, + Sequence<0, 1>, // DimAccessOrder, + 1, // SrcVectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // SrcScalarPerVector, + 1, // SrcScalarStrideInVector, + false // SrcResetCoordinateAfterRun, + >{c_partial_acc_block_m_n, + make_multi_index(thread_m_cluster_id, + thread_n_cluster_id * + CShuffleBlockTransferScalarPerVector_NPerBlock)}; + + auto acc_store = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, // SrcData, + CDataType, // DstData, + decltype(acc_thread_buf_store_desc), // SrcDesc, + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc, + CElementwiseOperation, // ElementwiseOperation, + Sequence<1, + 1, + 1, + CShuffleBlockTransferScalarPerVector_NPerBlock>, // SliceLengths, + Sequence<0, 1, 2, 3>, // DimAccessOrder, + 3, // DstVectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // DstScalarPerVector, + InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp, + 1, // DstScalarStrideInVector, + false // DstResetCoordinateAfterRun, + >{c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]), + thread_m_cluster_id, + __builtin_amdgcn_readfirstlane(spatial_idx[I1]), + thread_n_cluster_id * + CShuffleBlockTransferScalarPerVector_NPerBlock), + CElementwiseOperation{}}; + +#if 0 + if(threadIdx.x == 0) { + printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast(blockIdx.x), + reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end), + __builtin_amdgcn_readfirstlane(spatial_idx[I0]), + __builtin_amdgcn_readfirstlane(spatial_idx[I1])); + } +#endif + if(threadIdx.x == 0) + { + atomicAdd(&p_semaphore[reduction_idx], 1); + } + + wg_barrier.wait_eq(p_semaphore[reduction_idx], expected_count); + using Accumulation = ck::detail:: + AccumulateWithNanCheck; + + for(int i_m = 0; i_m < MReduceIters; i_m++) + { + static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) { + acc_buf.Clear(); + for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++) + { + auto c_partial_acc_buf = + make_dynamic_buffer( + reinterpret_cast(p_workspace) + + i * c_partial_acc_block_m_n.GetElementSpaceSize(), + c_partial_acc_block_m_n.GetElementSpaceSize()); + + acc_load.Run(c_partial_acc_block_m_n, + c_partial_acc_buf, + acc_thread_buf_load_desc, + make_tuple(I0, I0), + parcial_acc_buf); + + static_for<0, CShuffleBlockTransferScalarPerVector_NPerBlock, 1>{}( + [&](auto i_vec) { + constexpr auto offset = + acc_thread_buf_load_desc.CalculateOffset( + make_tuple(0, 
i_vec)); + Accumulation::Calculate(acc_buf(Number{}), + parcial_acc_buf[Number{}]); + }); + } + + if(thread_n_cluster_id * + CShuffleBlockTransferScalarPerVector_NPerBlock < + NPerBlock) + { + acc_store.Run(acc_thread_buf_store_desc, + make_tuple(I0, I0, I0, I0), + acc_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + if constexpr(NReduceIters != 1) + { + if constexpr(i_n_reduce != (NReduceIters - 1)) + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_n); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_n); + } + else + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_n_reverse); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_n_reverse); + } + } + }); + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_m); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_m); + } + } + + continue; + } + } + + // offset for last acc buffer of this block + uint32_t block_acc_offset = + (block_2_ctile_map_streamk.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * + MPerBlock * NPerBlock; + while(true) + { + + uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( + block_2_ctile_map_streamk.get_current_iter_length( + iter_start, iter_end, num_k_block_main_loop)); + uint32_t tile_idx, iter_offset; + block_2_ctile_map_streamk.get_tile_idx_with_offset( + iter_end - 1, tile_idx, iter_offset); + iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); + + auto block_work_idx = + block_2_ctile_map_streamk.tile_to_spatial(tile_idx, problem.M, problem.N); + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + const index_t k0_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(iter_offset * AK0Number); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = + GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = + GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 
0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + BElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), + a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = + make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = + make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle = + GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock + .GetElementSpaceSize()); + + auto c_partial_acc_buf = + make_dynamic_buffer( + reinterpret_cast(p_workspace) + block_acc_offset, + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + 
ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + // CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * + NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // LDS to global partial acc + auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + // InMemoryDataOperationEnum::Set, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * + NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + AccDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be + // false, othre wise has scratch + false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be + // false, othre wise has scratch + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + make_multi_index(0, 0, 0, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = SpaceFillingCurve< + Sequence<1, MPerBlock, 1, NPerBlock>, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); 
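+                    // Write-out stage: for every space-filling-curve access the C fragment is
+                    // copied from VGPR into the shared-memory CShuffle tile, then either streamed
+                    // to global memory (data-parallel blocks, or stream-k blocks using the atomic
+                    // strategy) or written to this block's partial-accumulation workspace
+                    // (stream-k blocks using the reduction strategy).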
+ + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + c_shuffle_block_copy_lds_to_global.SetSrcSliceOrigin( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple(0, 0, 0, 0)); + + if(is_dp_block) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + else if(is_sk_block) + { + if(problem.reduction_strategy == StreamKReductionStrategy::Atomic) + { + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + else if(problem.reduction_strategy == + StreamKReductionStrategy::Reduction) + { + // constexpr offset + c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple(0, 0, 0, 0)); + + c_block_copy_lds_to_partial_acc.SetDstSliceOrigin( + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + make_tuple(MXdlPerWave, 0, NXdlPerWave, 0)); + + c_block_copy_lds_to_partial_acc + .template Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + c_partial_acc_buf); + } + } + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + // exit condition + iter_end -= current_iter_length; + if(iter_end <= iter_start) + break; + if(problem.reduction_strategy == StreamKReductionStrategy::Reduction) + { + block_acc_offset -= MPerBlock * NPerBlock; + } + // make sure next loop LDS is ready for use + block_sync_lds(); + } + if(problem.reduction_strategy == StreamKReductionStrategy::Reduction) + { + if(is_sk_block) + { + // increase the counter for this tile + workgroup_barrier wg_barrier(p_semaphore); + wg_barrier.inc(0); + } + } + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp new file mode 100755 index 000000000000..7edcd7270fcb --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp @@ -0,0 +1,1084 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
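+
+// Gridwise GEMM with an XDL (MFMA) main loop and a CShuffle epilogue, single-LDS-buffer variant:
+// A/B tiles are staged in one shared-memory allocation, accumulated by the blockwise XDLOPS
+// pipeline, and the C tile is shuffled through LDS before the vectorized store to global memory.
+//
+// Minimal host-side usage sketch (the alias `Gemm`, the pointers, and the launch setup are
+// illustrative assumptions, not part of this header):
+//
+//   using Gemm = ck::GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1</* tuning parameters */>;
+//   Gemm::Argument arg{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
+//   if(Gemm::CheckValidity(arg))
+//   {
+//       const bool has_main_k_loop = Gemm::CalculateHasMainKBlockLoop(K);
+//       // dispatch on has_main_k_loop and launch kernel_gemm_xdl_cshuffle_v1<Gemm, ...>(arg)
+//   }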
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + typename GridwiseGemm::Problem problem) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, p_b_grid, p_c_grid, p_shared, problem); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = problem; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock) * MPerBlock; + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0(index_t K) + { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == 
GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + return CalculateKPadded(K) / AK1Value; + } + else + { + return K / AK1Value; + } + } + + __host__ static auto CalculateBK0(index_t K) + { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + return CalculateKPadded(K) / BK1Value; + } + else + { + return K / BK1Value; + } + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_floor(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_floor(N, NPerBlock); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return 
a_grid_desc_ak0_m_ak1; + } + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + 
make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KPadded{CalculateKPadded(K_)}, + AK0{CalculateAK0(K_)}, + BK0{CalculateBK0(K_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const FloatA* p_a_grid_, + const FloatB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const FloatA* p_a_grid; + const FloatB* p_b_grid; + FloatC* p_c_grid; + }; + + // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{} * AK1Number, AK1Number, I1)); + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{} * BK1Number, BK1Number, I1)); + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); 
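+        // CShuffle staging tile in LDS, shaped
+        // [1, CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+        //  1, CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl]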
+ + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) + + b_block_space_size_aligned * sizeof(ComputeTypeB)), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding) + { + if(!(CalculateKPadded(problem.K) % AK1Value == 0) || + !(CalculateKPadded(problem.K) % BK1Value == 0)) + { + return false; + } + } + else + { + if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + 
else + { + if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt; + + template + __device__ static void Run(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS 
memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatA, + ComputeTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatB, + ComputeTypeB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? 
true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ComputeTypeA, + ComputeTypeB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // gridwise GEMM pipeline + static_assert(std::is_default_constructible_v); + const auto gridwise_gemm_pipeline = GridwiseGemmPipe{}; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = 
ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp new file mode 100755 index 000000000000..f92268265f16 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp @@ -0,0 +1,1167 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
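+
+// Gridwise GEMM with an XDL (MFMA) main loop and a CShuffle epilogue, double-LDS-buffer variant:
+// the kernels allocate two separate __shared__ chunks and hand both to the blockwise pipeline so
+// LDS reads and writes of consecutive K tiles can overlap without an ordering dependency
+// (ping/pong buffering).
+//
+// Minimal host-side usage sketch (names and the launch setup are illustrative assumptions; the
+// Argument constructor is assumed to mirror the v1 header):
+//
+//   using Gemm = ck::GridwiseGemm_xdl_cshuffle_v2</* tuning parameters */>;
+//   const auto grid = Gemm::CalculateGridSize(M, N);
+//   // build a Gemm::Argument and launch
+//   // kernel_gemm_xdl_cshuffle_v2<Gemm, /* ... */><<<grid, BlockSize>>>(arg);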
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1) +#endif + kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid, + const FloatB* p_b_grid, + FloatC* p_c_grid, + typename GridwiseGemm::Problem problem) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, p_b_grid, p_c_grid, p_shared_0, p_shared_1, problem); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = problem; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_xdl_cshuffle_v2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock) * MPerBlock; + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, 
KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0(index_t K) + { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + return CalculateKPadded(K) / AK1Value; + } + else + { + return K / AK1Value; + } + } + + __host__ static auto CalculateBK0(index_t K) + { + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + return CalculateKPadded(K) / BK1Value; + } + else + { + return K / BK1Value; + } + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_floor(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_floor(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + 
a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + 
make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KPadded{CalculateKPadded(K_)}, + AK0{CalculateAK0(K_)}, + BK0{CalculateBK0(K_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ 
Argument(const FloatA* p_a_grid_, + const FloatB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const FloatA* p_a_grid; + const FloatB* p_b_grid; + FloatC* p_c_grid; + }; + + // FIXME: pass GridwiseGemmPipe as a template arguement into GridwiseGemm + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{} * AK1Number, AK1Number, I1)); + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{} * BK1Number, BK1Number, I1)); + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) + + b_block_space_size_aligned * sizeof(ComputeTypeB)), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == 
tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding) + { + if(!(CalculateKPadded(problem.K) % AK1Value == 0) || + !(CalculateKPadded(problem.K) % BK1Value == 0)) + { + return false; + } + } + else + { + if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + else + { + if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock; + + if(num_k_loop < 4) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return num_loop > 3; + } + + __host__ static constexpr index_t CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + if(num_loop % 2 == 1) + return 3; + else + return 2; + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + + template + __device__ static void Run(const FloatA* p_a_grid, + const FloatB* p_b_grid, + FloatC* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + 
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } +#if 0 + if(threadIdx.x == 0){ + printf("Hardware assigned No. %03d workgroup of logical C tile (%02d, %02d) on %d th XCC Die, %d th SE, %d th CU\n", + get_block_1d_id(), + block_work_idx[I0], + block_work_idx[I1], + __smid()>>6 & 0xf, + __smid()>>4 & 0x3, + __smid() & 0xf); + } +#endif + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatA, + ComputeTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatB, + ComputeTypeB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // 
c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); + + // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + // BlockSize, + // ComputeType, + // FloatGemmAcc, + // decltype(a_block_desc_ak0_m_ak1), + // decltype(b_block_desc_bk0_n_bk1), + // MPerXdl, + // NPerXdl, + // MXdlPerWave, + // NXdlPerWave, + // KPack, + // LoopSched>(); + auto blockwise_gemm_pipeline = BlockwiseGemmXdlops_pipeline_v4< + BlockSize, + ComputeTypeA, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)), + decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)), + MPerBlock, + NPerBlock, + KPerBlock, + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack>{}; // TransposeC + + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // gridwise GEMM pipeline + static_assert(std::is_default_constructible_v); + // const auto gridwise_gemm_pipeline = GridwiseGemmPipe{}; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! 
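+ // The write-out below runs in two hops: c_thread_copy_vgpr_to_lds first stages each
+ // thread's accumulator fragment into the CShuffle LDS buffer (reusing p_shared_0),
+ // then c_shuffle_block_copy_lds_to_global copies the reshuffled tile to the C tensor
+ // in global memory; both sides walk matching space-filling curves, with
+ // block_sync_lds() separating the LDS write and read phases of every access.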
+ constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + 
n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp new file mode 100755 index 000000000000..6270d0c4dc7e --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp @@ -0,0 +1,2202 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
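+//
+// This header implements GridwiseGemm_xdl_cshuffle_v3: a gridwise GEMM built on XDL
+// matrix instructions with a CShuffle epilogue and SplitK support. It also declares the
+// two kernel entry points, kernel_gemm_xdl_cshuffle_v3 (single LDS buffer) and
+// kernel_gemm_xdl_cshuffle_v3_2lds (two LDS buffers so ds_read/ds_write can overlap on
+// separate chunks), which dispatch to GridwiseGemm::Run and GridwiseGemm::Run_2Lds
+// respectively.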
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared_0, + p_shared_1, + karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +/// @brief \"Universal\" GEMM kernel with SplitK support. +/// +/// @par Overview +/// This GEMM kernel is carrying out following mathematical equation: +/// C{M,N} = C_op(A_op(A{M,K}) * B_op(B{K,N})) +/// Where A, B are input tensors and C is the output tensor. The A/B/C_op are +/// elementwise operations that could be applied on each tensor respectively. +/// The \"universal\" gemm comes with multiple pipelines optimized for different usage +/// scenarios. That's why it's called \"universal\". It's universal through it's design +/// and versatilty. +/// +/// @note This Kernel implementation supports SplitK algorithm. 
It can be configured +/// to split the dot product accumulated over the K dimension into multiple working groups. +/// The partial products of different workgroups are then reduced using the AtomicAdd +/// operation. +/// +/// @tparam ALayout A tensor data layout. +/// @tparam BLayout B tensor data layout. +/// @tparam CLayout C tensor data layout. +/// @tparam ADataType A tensor data type. +/// @tparam BDataType B tensor data type. +/// @tparam AccDataType The accumulation data type related to the hardware +/// matrix-multiplication instruction. +/// @tparam CShuffleDataType The data type used to store matrix-multiplication results into +/// LDS memory during \"CShuffle\" data layout optimization. +/// @tparam CDataType C tensor data type. +/// @tparam AElementwiseOperation Elementwise operation applied to the A input tensor elements. +/// @tparam BElementwiseOperation Elementwise operation applied to the B input tensor elements. +/// @tparam CElementwiseOperation Elementwise operation applied to the C output tensor +/// (after GEMM). +/// @tparam GemmSpec Determines used "padding" version. +/// @tparam BlockSize The number of threads within workgroup. +/// @tparam MPerBlock The input/output data tile size in the M dimension. +/// @tparam NPerBlock The input/output data tile size in the N dimension. +/// @tparam KPerBlock The input data tile size in the K dimension. +/// @tparam AK1Value The vector load size from global memory for A tensor. +/// @tparam BK1Value The vector load size from global memory for B tensor. +/// @tparam MPerXdl M size of matrix-fused-multiply-add instruction. +/// @tparam NPerXdl N size of matrix-fused-multiply-add instruction. +/// @tparam MXdlPerWave The number of iterations in the M dimension over output tile per wavefront. +/// @tparam NXdlPerWave The number of iterations in the N dimension over output tile per wavefront. +/// @tparam ABlockTransferThreadClusterLengths_AK0_M_AK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question, "How many threads can be +/// arranged on each input data axis?" +/// @tparam ABlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam ABlockTransferSrcAccessOrder The order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam ABlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam ABlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam ABlockTransferDstScalarPerVector_AK1 The size of vectorized store into LDS memory. +/// @tparam AThreadTransferSrcResetCoordinateAfterRun Decides whether we reset thread coordinate +/// (return back to the window origin) after all thread finish data copy. +/// @tparam ABlockLdsExtraM Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam BBlockTransferThreadClusterLengths_BK0_N_BK1 Spatial thread distribution over the input +/// data. Can be interpreted as the answer +/// to the question: "How many threads to +/// arrange on each input data axis?" 
+/// @tparam BBlockTransferThreadClusterArrangeOrder The order of thread spatial distribution over +/// the input tensor dimension. Can be interpreted +/// as the answer to the question: "In which +/// order to spread threads through tensor axes?". +/// @tparam BBlockTransferSrcAccessOrder he order of accessing input tensor axes. Can be +/// interpreted as the answer to the question "Which dimension +/// to read first? And which next?" etc. +/// @tparam BBlockTransferSrcVectorDim The index of axis on which we could do vectorized memory +/// access - the one with contiguous memory. +/// @tparam BBlockTransferSrcScalarPerVector The size of vector access instruction - the number of +/// elements accessed per thread per instruction. +/// @tparam BBlockTransferDstScalarPerVector_BK1 The size of vectorized store into LDS memory. +/// @tparam BThreadTransferSrcResetCoordinateAfterRun Decides whether we reset thread coordinate +/// (return back to the window origin) after all thread finish data copy. +/// @tparam BBlockLdsExtraN Whether to use padding for LDS or not. With +/// universal GEMM there's no need for padding. +/// @tparam CShuffleMXdlPerWavePerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in M dimension. +/// @tparam CShuffleNXdlPerWavePerShuffle The number of matrix-multiplication instructions +/// results to process per wave per iteration of CShuffle +/// in N dimension. +/// @tparam CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock The spatial +/// thread distribution used for storing data into output +/// tensor across output data layout dimensions. +/// @tparam CShuffleBlockTransferScalarPerVector_NPerBlock The size of vectorized memory access. +/// Used when storing data to output tensor. +/// @tparam BlkGemmPipeSched The version of blockwise-gemm pipeline scheduler (interwave or +/// intrawave). +/// @tparam BlkGemmPipelineVer The version of blockwise-gemm pipeline. +/// @tparam ComputeTypeA Data type used for A input of hardware matrix-multiplication +/// instructions. +/// @tparam ComputeTypeB Data type used for B input of hardware matrix-multiplication +/// instructions. +/// @tparam PermuteA Whether the A input tensor has gridwise-gemm friendly data layout +/// in global memory. Currently not supported! +/// @tparam PermuteB Whether the B input tensor has gridwise-gemm friendly data layout +/// in global memory (pre-shuffled). +/// @tparam DoElementwiseBeforeCShuffle Whether the cde_elementwise should be performed before or +/// after elementwise op. 
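+///
+/// @par SplitK arithmetic (illustrative sketch)
+/// The host helpers below split the K dimension as
+/// KRead = ceil(K / (KBatch * lcm(AK1, BK1))) * lcm(AK1, BK1);
+/// every workgroup with blockIdx.z < KBatch - 1 reduces over KRead elements of K and
+/// the last workgroup over the remainder (see CalculateKRead and SplitKBatchOffset).
+/// With hypothetical sizes K = 1000, KBatch = 4 and lcm(AK1, BK1) = 8 this gives
+/// KRead = ceil(1000 / 32) * 8 = 256, so the z = 0..2 slices each cover 256 elements
+/// and the z = 3 slice covers the remaining 232. The partial C tiles of the slices are
+/// then combined with AtomicAdd, unless the caller requests a separate reduction pass
+/// via Argument::is_reduce.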
+template +struct GridwiseGemm_xdl_cshuffle_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + // gfx950 double rate mfma16x16 require at least 128 KPerBlock to consume + ((is_same::value || is_same::value) && + KPerBlock < 128 && MPerXdl == 16)) + ? true + : false; + static constexpr auto is_scale_mfma = false; + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, 
index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + 
transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // Pre-shuffled Weight + // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1] + constexpr index_t BK01 = KPerBlock / BK1Value; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t 
StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t KBatch_, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Argument + struct Argument : public 
tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t k_batch_, + bool is_reduce_ = false, + AElementwiseOperation a_element_op = AElementwiseOperation{}, + BElementwiseOperation b_element_op = BElementwiseOperation{}, + CElementwiseOperation c_element_op = CElementwiseOperation{}) + : Problem{M_, + N_, + K_, + StrideA_, + StrideB_, + StrideC_, + k_batch_, + a_element_op, + b_element_op, + c_element_op}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + is_reduce(is_reduce_) + { + } + + __host__ __device__ inline bool IsReduceAdd() const + { + return (Problem::KBatch > 1) && is_reduce; + } + + __host__ __device__ inline bool IsAtomicAdd() const + { + return (Problem::KBatch > 1) && (!is_reduce); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + bool is_reduce; + }; + + struct SplitKBatchOffset + { + + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead / APackedSize; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + if constexpr(!PermuteB) + { + b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; + } + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + + if(karg.IsReduceAdd()) + { + c_reduce_offset = blockIdx.z * karg.M * karg.N; + } + else + { + c_reduce_offset = 0; + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t c_reduce_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{} * AK1Number, AK1Number, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; + constexpr auto MLdsLayer = LdsSize < 1 ? 
1 : LdsSize; + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{} * BK1Number, BK1Number, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; + constexpr index_t NLdsLayer = LdsSize < 1 ? 
1 : LdsSize; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + + b_block_space_size_aligned * 
sizeof(BDataType) / BPackedSize), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) + { + if(!karg.IsReduceAdd()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + if(karg.KBatch > 1) + { + return false; + } + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping 
+ // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + problem.a_element_op_, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + problem.b_element_op_, + 
b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * + sizeof(ADataType) / + APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
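            // Editor's note (explanatory, not part of the patch): the accumulator kept in
            // VGPRs is viewed as an 8-d tensor (M0, N0, M1, N1, M2, M3, M4, N2); per the
            // unmerge transform below, M1 = MWave, N1 = NWave, M2 * M3 * M4 = MPerXdl and
            // N2 = NPerXdl. The write-out drains it in chunks of
            // CShuffleMXdlPerWavePerShuffle x CShuffleNXdlPerWavePerShuffle XDL tiles:
            // VGPR -> LDS reshuffle, LDS barrier, then a blockwise LDS -> global copy that
            // vectorizes along NPerBlock (CShuffleBlockTransferScalarPerVector_NPerBlock).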
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + tensor_operation::element_wise::PassThrough pass_through{}; + const auto& vpgr_to_lds_element_op = [&] { + if constexpr(DoElementwiseBeforeCShuffle) + { + return problem.c_element_op_; + } + else + { + return pass_through; + } + }; + const auto& lds_to_global_element_op = [&] { + if constexpr(!DoElementwiseBeforeCShuffle) + { + return problem.c_element_op_; + } + else + { + return pass_through; + } + }; + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + 
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + conditional_t, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + vpgr_to_lds_element_op()}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + conditional_t, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + lds_to_global_element_op()}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = 
MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + problem.a_element_op_, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + 
decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + problem.b_element_op_, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + bit_cast(static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + bit_cast(bit_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto 
c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + problem.c_element_op_}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + 
p_shared_0, + p_shared_1, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp new file mode 100755 index 000000000000..8d5c844103f8 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -0,0 +1,1882 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_b_preshuffle(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + 
splitk_batch_offset.c_reduce_offset, + p_shared_0, + p_shared_1, + karg); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + // Use singal rate mfma instruction for this special case A (f8_t) * B (pk_i4_t) + // See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3 + // TODO: explore optimization opportunity by using new mfma instructions on gfx950 + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = true; + static constexpr auto is_scale_mfma = false; + static constexpr auto mfma = MfmaSelector{}; + static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); + static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops(); + + static constexpr index_t KRepeat = KPerBlock / KLane / KPack; + static constexpr index_t NLane = NPerXdl; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateBN0Shuffled(index_t N) + { + return math::integer_divide_ceil(N, NLane); + } + + __host__ __device__ static auto CalculateBK0Shuffled(index_t K) + { + return math::integer_divide_ceil(K, KLane * KPack); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + 
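    // Editor's note (worked example with assumed tuning values, not part of the patch):
    // taking MPerBlock = 128, KPerBlock = 256, AK1Value = 8, M = 1000, K = 4096, KBatch = 2:
    //   CalculateMPadded(1000)      = integer_least_multiple(1000, 128)       = 1024
    //   CalculateMBlock(1000)       = integer_divide_ceil(1000, 128)          = 8
    //   CalculateKPadded(4096, 2)   = ceil(4096 / (2 * 256)) * 256            = 2048
    //   CalculateAK0Padded(4096, 2) = ceil(4096 / (2 * 256)) * (256 / 8)      = 256
    // so each of the two split-K batches reads a KRead-sized slice of K, and
    // CalculateGridSize adds KBatch as the third grid dimension.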
template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) + { + constexpr index_t NkSwizzleNumber = Number{}; + return make_naive_tensor_descriptor( + make_tuple(N0 / NWave, NWave, K0, 
NkSwizzleNumber), + make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // Pre-shuffled Weight + // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1] + constexpr index_t BK01 = KPerBlock / BK1Value; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + 
make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + // constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + // return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)}, + BN0Shuffled{CalculateBN0Shuffled(N_)}, + BK0Shuffled{CalculateBK0Shuffled(K_)} + { + } + + __host__ void 
Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + // For B pre-shuffle only + index_t BN0Shuffled; + index_t BK0Shuffled; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t k_batch_, + bool is_reduce_ = false) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + is_reduce(is_reduce_) + { + } + + __host__ __device__ inline bool IsReduceAdd() const + { + return (Problem::KBatch > 1) && is_reduce; + } + + __host__ __device__ inline bool IsAtomicAdd() const + { + return (Problem::KBatch > 1) && (!is_reduce); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + bool is_reduce; + }; + + struct SplitKBatchOffset + { + + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead / APackedSize; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + if constexpr(!PermuteB) + { + // b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; + + b_k_split_offset = blockIdx.z * karg.KRead * NLane / BPackedSize; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; + } + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + + if(karg.IsReduceAdd()) + { + c_reduce_offset = blockIdx.z * karg.M * karg.N; + } + else + { + c_reduce_offset = 0; + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t c_reduce_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. 
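        // Editor's note (explanatory, not part of the patch): the branch below handles
        // row-major A by storing it to LDS through an xor-with-modulo permutation of the
        // (K0, M) index pair. The xor swizzle spreads a wavefront's writes across LDS banks
        // so the later MFMA-side reads avoid bank conflicts without growing the allocation,
        // while the ABlockLdsExtraM path above appears to rely on padding the LDS tile
        // instead, trading a little extra LDS for simpler addressing.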
+ else if constexpr(is_same::value) + { + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return a_lds_block_desc_permuted; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto 
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + I1, + Number{}, + Number{})); //??? BK1Value same as KPack? + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(ADataType) / APackedSize, + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! 
N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! 
" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) + { + if(!karg.IsReduceAdd()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + if(karg.KBatch > 1) + { + return false; + } + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + // const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t 
m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix threadwise copy, using threadwiseTensorSliceTransfer_v2 + auto b_block_buf = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack * (get_thread_local_1d_id() % WarpSize))); + + // LDS allocation for A and B: be careful of alignment + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
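+            // C write-out path: each thread's accumulators sit in VGPRs with the
+            // (M0, N0, M1, N1, M2, M3, M4, N2) layout exposed by the blockwise GEMM pipeline.
+            // They are staged through LDS one CShuffle tile at a time (see the space-filling
+            // curves below) so the final global store can issue contiguous
+            // CShuffleBlockTransferScalarPerVector_NPerBlock-wide vectors along N.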
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto 
c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bpreshuffled, + 
c_grid_desc_mblock_mperblock_nblock_nperblock); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + // const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + 2>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf_ping = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + 
true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack * (get_thread_local_1d_id() % WarpSize))); + + // LDS allocation for A and B: be careful of alignment + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
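+            // Same shuffle/write-out scheme as the single-LDS Run above; after the ping-pong
+            // main loop (A tiles alternating between p_shared_0 and p_shared_1) the epilogue
+            // reuses p_shared_0 as the CShuffle staging buffer, so no additional LDS is needed.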
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto 
c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_0, + p_shared_1, + problem, + 
a_grid_desc_ak0_m_ak1, + b_grid_desc_bpreshuffled, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp new file mode 100755 index 000000000000..93c1779a80d4 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp @@ -0,0 +1,2226 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, + p_shared, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + 
splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, + p_shared_0, + p_shared_1, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemm_xdl_cshuffle_v3 +{ + using BScaleType = ck::half_t; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? true + : false; + static constexpr auto is_scale_mfma = false; + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 
2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + + if constexpr(GemmSpec == 
GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // Weight Tile Permute + constexpr index_t BK01 = KPerBlock / BK1Value; + // const index_t BK00 = BK0 / BK01; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ 
static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + StrideScaleB{StrideScaleB_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "SScaleB:" << StrideScaleB << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t StrideScaleB; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const 
BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + const BScaleType* p_b_scale_grid_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_, + bool is_reduce_ = false) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, StrideScaleB_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_}, + is_reduce(is_reduce_) + { + } + + __host__ __device__ inline bool IsReduceAdd() const + { + return (Problem::KBatch > 1) && is_reduce; + } + + __host__ __device__ inline bool IsAtomicAdd() const + { + return (Problem::KBatch > 1) && (!is_reduce); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + + const BScaleType* p_b_scale_grid; + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + bool is_reduce; + }; + + struct SplitKBatchOffset + { + + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) + { + if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead / APackedSize; + } + else if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + if constexpr(!PermuteB) + { + b_k_split_offset = k_id * karg.KRead / BPackedSize; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = k_id * k0_offset / BPackedSize; + } + } + + // Calculate B scale offset + if constexpr(is_same_v) + { + scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB; + } + else if constexpr(is_same_v) + { + scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK); + } + + if(k_id < (karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + + if(karg.IsReduceAdd()) + { + c_reduce_offset = k_id * karg.M * karg.N; + } + else + { + c_reduce_offset = 0; + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t scale_k_split_offset; // New member for scale matrix offset + index_t c_reduce_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; + constexpr auto MLdsLayer = LdsSize < 1 ? 
1 : LdsSize; + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; + constexpr index_t NLdsLayer = LdsSize < 1 ? 
1 : LdsSize; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + + b_block_space_size_aligned * 
sizeof(BDataType) / BPackedSize), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) + { + if(!karg.IsReduceAdd()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + if(karg.KBatch > 1) + { + return false; + } + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping 
+ // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // B Scale buffer + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + 
BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * + sizeof(ADataType) / + APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // b scale + // static_assert(KPerBlock <= ScaleBlockK); + static constexpr auto mfma = + MfmaSelector{}; + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + static constexpr auto KPerThread = KPerBlock / K0PerXdlops; + + static constexpr auto ScaleSliceSizeN = NXdlPerWave; + static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK; + static constexpr auto KBlockScaleSliceSizeK = (KPerBlock + ScaleBlockK - 1) / ScaleBlockK; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + auto b_thread_offset_n = + get_thread_local_1d_id() % NPerXdl + (get_thread_local_1d_id() / 64) % NWaves * NPerXdl; + auto b_thread_offset_k = (get_thread_local_1d_id() % 64) / NPerXdl * KPerThread; + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, + b_thread_offset_k / ScaleBlockK)); + + constexpr auto b_scale_thread_slice_copy_step = + make_tuple(make_multi_index(NWaves * NPerXdl, 0), + make_multi_index(-NPerBlock, 0), + make_multi_index(-NPerBlock, KBlockScaleSliceSizeK)); + + const index_t num_k_block_per_scale = (ScaleBlockK + KPerBlock - 1) / KPerBlock; + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + b_scale_thread_slice_copy_step, + num_k_block_main_loop, + 
num_k_block_per_scale); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + 
ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + 
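+        // Illustrative shape check (example values, not defaults): with K = 4096,
+        // KBatch = 1, KPerBlock = 64 and AK1 = 8, CalculateAK0Padded gives
+        // ceil(4096 / 64) * (64 / 8) = 64 * 8 = 512, i.e. AK0 = K / AK1, so the A view
+        // built here is (AK0, M, AK1) = (512, M, 8) before any GemmSpec padding.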
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // B Scale grid + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(problem.StrideScaleB, 1)); + + Run(p_a_grid, + p_b_grid, + p_c_grid, + p_b_scale_grid, + p_shared, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b_scale_grid_desc_bn_ak, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BScaleType* p_b_scale_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // B Scale buffer + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + 
ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + bit_cast(static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) / APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + bit_cast(bit_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType) / APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // B scale + static constexpr auto mfma = + MfmaSelector{}; + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + static constexpr auto KPerThread = KPerBlock / K0PerXdlops; + + const index_t ScaleSliceSizeN = NXdlPerWave; + static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK; + static constexpr auto KBlockScaleSliceSizeK = (KPerBlock + ScaleBlockK - 1) / ScaleBlockK; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + auto b_thread_offset_n = + get_thread_local_1d_id() % NPerXdl + 
(get_thread_local_1d_id() / 64) % NWaves * NPerXdl; + auto b_thread_offset_k = (get_thread_local_1d_id() % 64) / NPerXdl * KPerThread; + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, + b_thread_offset_k / ScaleBlockK)); + + constexpr auto b_scale_thread_slice_copy_step = + make_tuple(make_multi_index(NWaves * NPerXdl, 0), + make_multi_index(-NPerBlock, 0), + make_multi_index(-NPerBlock, KBlockScaleSliceSizeK)); + + const index_t num_k_block_per_scale = (ScaleBlockK + KPerBlock - 1) / KPerBlock; + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + b_scale_thread_slice_copy_step, + + num_k_block_main_loop, + num_k_block_per_scale); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix 
starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + 
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BScaleType* p_b_scale_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(problem.StrideScaleB, 1)); + + Run_2Lds(p_a_grid, + p_b_grid, + p_c_grid, + p_b_scale_grid, + p_shared_0, + p_shared_1, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b_scale_grid_desc_bn_ak, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp new file mode 100755 index 000000000000..97d0e2a4eb70 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_abd.hpp @@ -0,0 +1,2507 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. 
Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + karg.p_as_grid, + karg.p_bs_grid, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds( + karg.p_as_grid, + karg.p_bs_grid, + karg.p_ds_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemm_xdl_cshuffle_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + using LDSTypeA = ComputeTypeA; + using LDSTypeB = ComputeTypeB; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t NumATensor = AsDataType::Size(); + static constexpr index_t NumBTensor = BsDataType::Size(); + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto MakeAsGridPointer() + { + return generate_tuple( + [&](auto i) { + using ADataType_ = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + static constexpr auto MakeBsGridPointer() + { + return generate_tuple( + [&](auto i) { + using BDataType_ = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using AsGridPointer = decltype(MakeAsGridPointer()); + using BsGridPointer = decltype(MakeBsGridPointer()); + using DsGridPointer = decltype(MakeDsGridPointer()); + + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + 
(is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? true + : false; + static constexpr auto is_scale_mfma = false; + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + 
return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto + MakeAsGridDescriptor_AK0_M_AK1(const index_t M, + const index_t MPad, + const index_t K, + const index_t KPad, + const std::array& StrideAs, + const index_t AK0) + { + return generate_tuple( + [&](auto i) { + return MakeAGridDescriptor_AK0_M_AK1(M, MPad, K, KPad, StrideAs[i], AK0); + }, + Number{}); + } + + __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, 
Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + __host__ __device__ static auto + MakeBsGridDescriptor_BK0_N_BK1(const index_t K, + const index_t KPad, + const index_t N, + const index_t NPad, + const std::array& StrideBs, + const index_t BK0) + { + return generate_tuple( + [&](auto i) { + return MakeBGridDescriptor_BK0_N_BK1(K, KPad, N, NPad, StrideBs[i], BK0); + }, + Number{}); + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, 
Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { return MakeCGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); }, + Number{}); + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + std::array StrideAs_, + std::array StrideBs_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideAs{StrideAs_}, + StrideBs{StrideBs_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + + std::array StrideAs; + std::array StrideBs; + std::array StrideDs; + index_t StrideC; + + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(std::array p_as_grid_, + std::array p_bs_grid_, + std::array p_ds_grid_, + void* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + std::array StrideAs_, + std::array StrideBs_, + std::array StrideDs_, + index_t StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{M_, N_, K_, StrideAs_, StrideBs_, StrideDs_, StrideC_, k_batch_}, + p_as_grid{}, + p_bs_grid{}, + p_ds_grid{}, + p_c_grid{static_cast(p_c_grid_)}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + + { + // populate pointer, desc for As + static_for<0, NumATensor, 1>{}([&](auto i) { + using ADataType_ = remove_cvref_t>; + + // A pointer + p_as_grid(i) = static_cast(p_as_grid_[i]); + }); + + // populate pointer, desc for Bs + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BDataType_ = remove_cvref_t>; + + // B pointer + p_bs_grid(i) = static_cast(p_bs_grid_[i]); + }); + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + AsGridPointer p_as_grid; + BsGridPointer p_bs_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.M; + } + + if 
constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.N; + } + else if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead; + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + +#if 0 + struct SplitKBatchOffsetMultiABD + { + __device__ SplitKBatchOffsetMultiABD(AsGridPointer& p_as_grid, + BsGridPointer& p_bs_grid, + Argument& karg) + { + static_for<0, NumATensor, 1>{}([&](auto i) { + using ALayout_ = remove_cvref_t>; + if constexpr(is_same_v) + { + as_k_split_offset[i] = blockIdx.z * karg.KRead; + } + else if constexpr(is_same_v) + { + as_k_split_offset[i] = blockIdx.z * karg.KRead * karg.StrideAs[i]; + } + + p_as_grid_(i) = p_as_grid[i] + as_k_split_offset[i]; + }); + + static_for<0, NumBTensor, 1>{}([&](auto i) { + using BLayout_ = remove_cvref_t>; + if constexpr(is_same_v) + { + bs_k_split_offset[i] = blockIdx.z * karg.KRead * karg.StrideBs[i]; + } + else if constexpr(is_same_v) + { + bs_k_split_offset[i] = blockIdx.z * karg.KRead; + } + + p_bs_grid_(i) = p_bs_grid[i] + bs_k_split_offset[i]; + }); + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + AsGridPointer p_as_grid_; + BsGridPointer p_bs_grid_; + std::array as_k_split_offset; + std::array bs_k_split_offset; + }; +#endif + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(LDSTypeA); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. 
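+            // (kfold and mpair exist only to pack narrow rows into a full 128-byte span of
+            // LDS banks; each collapses to 1 when a single row already covers 128 bytes.)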
+ // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1 + ? 
1 + : 32 * 4 / KPerBlock / sizeof(LDSTypeB); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(LDSTypeB)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) + + b_block_space_size_aligned * sizeof(LDSTypeB)), + 
c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + + using DsGridDesc_M_N = remove_cvref_t; + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(AsGridPointer& p_as_grid, + BsGridPointer& p_bs_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& 
b_element_op, + const CElementwiseOperation& c_element_op) + { + // std::array StrideAs = {problem.StrideA}; + // std::array StrideBs = {problem.StrideB}; + + // AsGridPointer p_as_grid; + // BsGridPointer p_bs_grid; + // DsGridPointer p_ds_grid; + + // const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + // problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + // const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + // problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0); + const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + +#if 0 + static_for<0, NumDTensor, 1>{}([&](auto j) { + ds_grid_desc_m_n(j) = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs[j]); + }); +#endif + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // const auto a_grid_buf = make_dynamic_buffer( + // p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + // const auto b_grid_buf = make_dynamic_buffer( + // p_bs_grid[I0], b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + const auto as_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize()); + }, + Number{}); + + const auto bs_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize()); + }, + Number{}); + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A 
matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + +#if 0 + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); +#else + const auto idx_as_block_begin = + generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, + Number{}); + + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + AsDataType, + Tuple, + decltype(as_grid_desc_ak0_m_ak1), + decltype(tie(a_block_desc_ak0_m_ak1)), + AElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + uniform_sequence_gen_t, + Sequence, + BlockwiseGemmPipe::GlobalBufferNum>{as_grid_desc_ak0_m_ak1, + idx_as_block_begin, + tie(a_block_desc_ak0_m_ak1), + make_tuple(make_multi_index(0, 0, 0)), + a_element_op}; +#endif + +#if 0 + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); +#else + const auto idx_bs_block_begin = + generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, + Number{}); + + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + BsDataType, + Tuple, + decltype(bs_grid_desc_bk0_n_bk1), + decltype(tie(b_block_desc_bk0_n_bk1)), + BElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + uniform_sequence_gen_t, + Sequence, + BlockwiseGemmPipe::GlobalBufferNum>{bs_grid_desc_bk0_n_bk1, + idx_bs_block_begin, + tie(b_block_desc_bk0_n_bk1), + make_tuple(make_multi_index(0, 0, 0)), + b_element_op}; 
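+        // Compared with the disabled v4r1 path above, the v7r2 transfers take tuples of
+        // sources: all NumATensor (resp. NumBTensor) grid tensors are streamed into the
+        // single A (resp. B) LDS tile, with the A/B elementwise operation applied as the
+        // data is staged.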
+ +#endif + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(as_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + as_grid_buf, + a_block_buf, + a_block_slice_copy_step, + bs_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + bs_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + +#if 0 + // shuffle: blockwise copy C from LDS to global + auto 
c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; +#else + using EDataType = CDataType; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + const auto CDEShuffleBlockTransferScalarPerVector_NPerBlock = + CShuffleBlockTransferScalarPerVector_NPerBlock; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // 
ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + +#endif + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); +#if 0 + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + +#else + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); +#endif + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + +#if 0 + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } +#else + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } +#endif + }); + } + } + +#if 1 + template + __device__ static void Run_2Lds(AsGridPointer& p_as_grid, + BsGridPointer& p_bs_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op) + { + // const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + // problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + // const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + // problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto 
as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0); + const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // const auto a_grid_buf = make_dynamic_buffer( + // p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + // const auto b_grid_buf = make_dynamic_buffer( + // p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + const auto as_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize()); + }, + Number{}); + + const auto bs_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize()); + }, + Number{}); + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + +#if 0 + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + 
BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); +#else + const auto idx_as_block_begin = + generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); }, + Number{}); + + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + AsDataType, + Tuple, + decltype(as_grid_desc_ak0_m_ak1), + decltype(tie(a_block_desc_ak0_m_ak1)), + AElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + uniform_sequence_gen_t, + Sequence, + BlockwiseGemmPipe::GlobalBufferNum>{as_grid_desc_ak0_m_ak1, + idx_as_block_begin, + tie(a_block_desc_ak0_m_ak1), + make_tuple(make_multi_index(0, 0, 0)), + a_element_op}; + +#endif + +#if 0 + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); +#else + const auto idx_bs_block_begin = + generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); }, + Number{}); + + auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + BsDataType, + Tuple, + decltype(bs_grid_desc_bk0_n_bk1), + decltype(tie(b_block_desc_bk0_n_bk1)), + BElementwiseOperation, + Sequence(InMemoryDataOperationEnum::Set)>, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + uniform_sequence_gen_t, + Sequence, + BlockwiseGemmPipe::GlobalBufferNum>{bs_grid_desc_bk0_n_bk1, + idx_bs_block_begin, + tie(b_block_desc_bk0_n_bk1), + make_tuple(make_multi_index(0, 0, 0)), + b_element_op}; +#endif + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + 
b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(as_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + as_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + bs_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + bs_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + 
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + +#if 0 + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; +#else + using EDataType = CDataType; + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + 
generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + const auto CDEShuffleBlockTransferScalarPerVector_NPerBlock = + CShuffleBlockTransferScalarPerVector_NPerBlock; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + CDEShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)), + c_element_op}; + +#endif + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + +#if 1 + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; +#endif + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + +#if 0 + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + 
c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } +#else + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } +#endif + }); + } + } +#endif +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp new file mode 100755 index 000000000000..c8dbd81b7395 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp @@ -0,0 +1,2166 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" + +#define DEBUG_LOG 0 + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemmMultiD_xdl_cshuffle_v3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + // gfx950 double rate mfma16x16 require at least 128 KPerBlock to consume + ((is_same::value || is_same::value) && + KPerBlock < 128 && MPerXdl == 16)) + ? 
true + : false; + static constexpr auto is_scale_mfma = false; + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMapDefault::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ __device__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + 
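+            // [Illustrative note added by the editor; the example numbers below are
+            //  assumptions, not taken from this patch.] With MK padding the A tile is
+            //  exposed as an (AK0, MPad, AK1) view of the right-padded matrix, e.g.
+            //      M = 100, K = 300, KBatch = 1, MPerBlock = 128, KPerBlock = 64, AK1Value = 8
+            //      MPad = 128, KPad = ceil(300 / 64) * 64 = 320
+            //      AK0  = ceil(300 / 64) * (64 / 8)      = 40   (so AK0 * AK1Value == KPad)
+            //  giving descriptor lengths (40, 128, 8); the right-pad transforms let the
+            //  blockwise copy always operate on full MPerBlock x KPerBlock tiles.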
return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), 
make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return 
MakeCGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + + using DsGridDesc_M_N = remove_cvref_t; + + struct Problem + { + __host__ __device__ Problem() = default; + __host__ __device__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument() = default; + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{}, + p_c_grid{p_c_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + { + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) + { + if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead; + } + + if(k_id 
< karg.KBatch - 1) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{} * AK1Number, AK1Number, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) < 1 + ? 1 + : 32 * 4 / KPerBlock / sizeof(LDSTypeA); + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{} * BK1Number, BK1Number, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeB) < 1 + ? 
1 + : 32 * 4 / KPerBlock / sizeof(LDSTypeB); + ; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(LDSTypeB)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) + + b_block_space_size_aligned * sizeof(LDSTypeB)), + 
c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { +#if DEBUG_LOG + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_c_grid, + p_shared, + problem, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + 
const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + LDSTypeB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + 
BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
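+            // [Editor's sketch of the epilogue flow; added comment, not part of the
+            //  original patch.] Each iteration of the space-filling-curve loop further
+            //  below (1) syncs the block, (2) copies one shuffle tile of MFMA
+            //  accumulators from VGPR into the shared c_shuffle buffer, (3) syncs
+            //  again, and (4) streams that tile to global memory with
+            //  ThreadGroupTensorSliceTransfer_v7r3, combining the optional D tensors
+            //  through the CElementwiseOperation before the final store.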
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = 
MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's 
safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; + Run_2Lds( + p_a_grid, + p_b_grid, + p_ds_grid, + p_c_grid, + p_shared_0, + p_shared_1, + problem, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + 
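// (Explanatory note: __builtin_amdgcn_readfirstlane returns the first active lane's value as a wave-uniform scalar, + // so the compiler can keep these block offsets in SGPRs instead of replicating them in per-lane VGPRs.) +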
const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + LDSTypeB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) 
* a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + 
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // 
ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp new file mode 100755 index 000000000000..64fbda7a44e6 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp @@ -0,0 +1,1688 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" + +#define DEBUG_LOG 0 + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. 
Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + karg.p_a_grid, + karg.p_b_grid, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 +{ + using AScaleType = float; + using BScaleType = float; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8)) + ? 
true + : false; + static constexpr auto is_scale_mfma = false; + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == 
GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); 
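+            // Unmerge the K-padded descriptor's K dimension into (BK0, BK1) and pass N through unchanged.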
+ + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeCGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto 
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + + using DsGridDesc_M_N = remove_cvref_t; + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideC; + + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + const AScaleType* p_a_scale_grid_, + const BScaleType* p_b_scale_grid_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{}, + p_c_grid{p_c_grid_}, + p_a_scale_grid{p_a_scale_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + { + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + const AScaleType* p_a_scale_grid; + const BScaleType* p_b_scale_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.M; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.N; + } + else if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * 
karg.KRead; + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return a_lds_block_desc_permuted; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + constexpr auto b_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return b_lds_block_desc_permuted; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(LDSTypeB) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(LDSTypeB)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(LDSTypeB) > 128) + ? 
1 + : ((128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))) > N0 + ? N0 + : 128 / (BK1Number * NPerXdl * sizeof(LDSTypeB))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) 
+ + b_block_space_size_aligned * sizeof(LDSTypeB)), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { +#if DEBUG_LOG + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + const AScaleType* p_a_scale_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, 
problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + 
BBlockTransferThreadClusterArrangeOrder, + BDataType, + LDSTypeB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + + a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + constexpr index_t ScaleSliceSizeM = MXdlPerWave; + constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN); + constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); + + // ScaleSliceSizeK is last dimension in A/B scale for vector memory access + // ScaleSliceSizeK is first dimension in C scale for packed math + constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + auto a_thread_offset = + get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, Number{})); + + auto a_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0)); + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); + + // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1); + constexpr auto a_scale_thread_slice_copy_step = + make_tuple(make_multi_index(MWaves * MPerXdl, 0), + make_multi_index(-MPerBlock, 0), + make_multi_index(-MPerBlock, ScaleSliceSizeK)); + constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK); + + constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock); + + 
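// The main loop steps K by KPerBlock, while the A/B scales are defined per ScaleBlockK; + // NumKBlockPerScale = ceil(ScaleBlockK / KPerBlock) is presumably the number of K blocks that share one + // scale value (e.g. with illustrative values ScaleBlockK = 128 and KPerBlock = 64, each scale is reused for two iterations). +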
blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + + c_scale_thread_desc, + c_thread_buf, + + a_scale_grid_desc_am_ak, + a_scale_thread_desc, + a_scale_thread_copy, + a_scale_grid_buf, + a_scale_thread_slice_copy_step, + + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + b_scale_thread_slice_copy_step, + + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // transposed XDL + // // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + // // TODO: hacky, fix it! + // only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2)), // M2 = MPerXdl + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + 
m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), + tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // 
index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp new file mode 100755 index 000000000000..3553a1d040c8 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -0,0 +1,1988 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
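+// Gridwise GEMM (XDLOPS mainloop + CShuffle epilogue) with a pre-shuffled B operand
+// and an optional set of extra D tensors fused into the epilogue. Two kernel entry
+// points are provided below: a single-LDS-buffer variant and a ping-pong (2-LDS)
+// variant that double-buffers the A tile in shared memory.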
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" + +#define DEBUG_LOG 0 + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + // K1 should 
be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + static constexpr auto BlockSizeNumber = Number{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && KPerBlock < 128) || + (is_same::value && KPerBlock < 128)) + ? true + : false; + static constexpr auto is_scale_mfma = false; + static constexpr auto mfma = MfmaSelector{}; + static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk); + static constexpr index_t KGroup = []() { + if constexpr(is_same_v, f8_t>) + // On gfx950, we have a mfma that required 32 f8 elements as input, + // splited into 2 groups of 16 f8 elements. + // the 2 groups is not contiguous in the B preshuffed layout. + // and we do not want it to be contiguous in the B preshuffled layout + // because a memory instruction can only read 16 f8 elements at a time. + return mfma.selected_mfma.k_per_blk == 32 ? 2 : 1; + else + return 1; + }(); + static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops(); + static constexpr index_t KPackPerGroup = KPack / KGroup; + static constexpr index_t KRepeat = KPerBlock / KLane / KPackPerGroup; + static constexpr index_t NLane = NPerXdl; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMapDefault::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateBN0Shuffled(index_t N) + { + return math::integer_divide_ceil(N, NLane); + } + __host__ __device__ static auto CalculateBK0Shuffled(index_t K) + { + return math::integer_divide_ceil(K, KLane * KPackPerGroup); + } + + __host__ __device__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ __device__ static auto CalculateMBlock(index_t M) + { + return 
math::integer_divide_ceil(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto 
MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) + { + constexpr index_t NkSwizzleNumber = Number{}; + return make_naive_tensor_descriptor( + make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), + make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ 
static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeCGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + + using DsGridDesc_M_N = remove_cvref_t; + + struct Problem + { + __host__ __device__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)}, + BN0Shuffled{CalculateBN0Shuffled(N_)}, + BK0Shuffled{CalculateBK0Shuffled(K_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << 
"SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + // FOR PRESHUFFLE ONLY + index_t BN0Shuffled; + index_t BK0Shuffled; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{}, + p_c_grid{p_c_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + { + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) + { + if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + // KPack * NLane * KLane * K0 * N0 + b_k_split_offset = k_id * karg.KRead * NLane; + } + + if(k_id < karg.KBatch - 1) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. 
+ else if constexpr(is_same::value) + { + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return a_lds_block_desc_permuted; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto 
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(LDSTypeA), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { +#if DEBUG_LOG + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! 
K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! 
" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + // check gridwise gemm pipeline +#if 1 + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } +#endif + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_c_grid, + p_shared, + problem, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + 
if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPackPerGroup * (get_thread_local_1d_id() % WarpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); 
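+            // Epilogue: accumulators are shuffled CShuffleMXdlPerWavePerShuffle x
+            // CShuffleNXdlPerWavePerShuffle XDL tiles at a time from VGPR into LDS,
+            // combined with the D tensors, and streamed out to global memory along
+            // a space-filling curve (see sfc_c_vgpr / sfc_cde_block below).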
+ + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + 
m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + 
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + void* p_shared1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; + Run_2Lds( + p_a_grid, + p_b_grid, + p_ds_grid, + p_c_grid, + p_shared, + p_shared1, + problem, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + void* p_shared1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t 
block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + 2>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf_ping = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPackPerGroup * (get_thread_local_1d_id() % WarpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % 
CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + 
make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + 
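// ---------------------------------------------------------------------------
// Editorial note (rough host-side sketch, not part of this header): the space
// filling curves declared around here walk the accumulator tile in chunks of
// CShuffleMXdlPerWavePerShuffle x CShuffleNXdlPerWavePerShuffle out of
// MXdlPerWave x NXdlPerWave sub-tiles, so the number of LDS round trips is the
// ratio of the two products. The constants below are placeholder examples.
constexpr int num_shuffle_passes(int m_xdl_per_wave, int n_xdl_per_wave,
                                 int m_per_shuffle, int n_per_shuffle)
{
    return (m_xdl_per_wave / m_per_shuffle) * (n_xdl_per_wave / n_per_shuffle);
}

static_assert(num_shuffle_passes(4, 4, 2, 2) == 4, "4x4 sub-tiles, 2x2 per pass");
static_assert(num_shuffle_passes(2, 2, 1, 1) == 4, "2x2 sub-tiles, 1x1 per pass");
// ---------------------------------------------------------------------------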
constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp new file mode 100755 index 000000000000..909376e5f7d3 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp @@ -0,0 +1,2080 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3.hpp" + +#define DEBUG_LOG 0 + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle( + typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + karg.p_a_grid, + karg.p_b_grid, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds( + typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid, + karg.p_b_grid, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle +{ + using AScaleType = float; + using BScaleType = float; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + static constexpr auto BlockSizeNumber = Number{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + using mfma_selector = MfmaSelector; + static constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); + static constexpr index_t KGroup = []() { + if constexpr(is_same_v, f8_t>) + // On gfx950, we have a mfma that required 32 f8 elements as input, + // splited into 2 groups of 16 f8 elements. + // the 2 groups is not contiguous in the B preshuffed layout. + // and we do not want it to be contiguous in the B preshuffled layout + // because a memory instruction can only read 16 f8 elements at a time. + return mfma_selector::selected_mfma.k_per_blk == 32 ? 
2 : 1; + else + return 1; + }(); + static constexpr index_t KLane = + mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); + static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup); + static constexpr index_t NLane = NPerXdl; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateBN0Shuffled(index_t N) + { + return math::integer_divide_ceil(N, NLane); + } + __host__ __device__ static auto CalculateBK0Shuffled(index_t K) + { + return math::integer_divide_ceil(K, KLane * KPack / KGroup); + } + + __host__ __device__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ __device__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == 
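// ---------------------------------------------------------------------------
// Editorial note (illustrative host-side sketch, not part of this header): the
// Calculate* helpers above are plain round-up arithmetic. ceil_div and
// least_multiple below mirror integer_divide_ceil / integer_least_multiple;
// the values of M, K, MPerBlock, KPerBlock, AK1 and KBatch are made-up
// examples, not this kernel's tuning parameters.
#include <cstdio>

constexpr int ceil_div(int x, int y) { return (x + y - 1) / y; }
constexpr int least_multiple(int x, int y) { return ceil_div(x, y) * y; }

int main()
{
    constexpr int M = 1000, K = 4000;
    constexpr int MPerBlock = 128, KPerBlock = 256, AK1 = 8, KBatch = 2;

    // MPadded rounds M up to a whole number of M tiles.
    std::printf("MPadded = %d\n", least_multiple(M, MPerBlock)); // 1024
    // AK0 is the per-split K extent measured in AK1-wide chunks.
    std::printf("AK0     = %d\n", ceil_div(K, KBatch * KPerBlock) * (KPerBlock / AK1)); // 256
    // KPadded rounds the per-split K up to a whole number of K tiles.
    std::printf("KPadded = %d\n", ceil_div(K, KBatch * KPerBlock) * KPerBlock); // 2048
    return 0;
}
// ---------------------------------------------------------------------------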
GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) + { + constexpr index_t NkSwizzleNumber = Number{}; + return make_naive_tensor_descriptor( + make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), + make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + 
b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeCGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, 
NBlock); + }, + Number{}); + } + + using DsGridDesc_M_N = remove_cvref_t; + + struct Problem + { + __host__ __device__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)}, + BN0Shuffled{CalculateBN0Shuffled(N_)}, + BK0Shuffled{CalculateBK0Shuffled(K_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + // FOR PRESHUFFLE ONLY + index_t BN0Shuffled; + index_t BK0Shuffled; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + const AScaleType* p_a_scale_grid_, + const BScaleType* p_b_scale_grid_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{}, + p_c_grid{p_c_grid_}, + p_a_scale_grid{p_a_scale_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + { + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + const AScaleType* p_a_scale_grid; + const BScaleType* p_b_scale_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.M; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.N; + } + else if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead; + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + 
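// ---------------------------------------------------------------------------
// Editorial note (host-side sketch, not part of this patch): SplitKBatchOffset
// hands every grid.z slice KRead elements of K except the last slice, which
// takes the remainder, so the per-slice K values sum back to the full K. The
// helper names and the numbers in the static_asserts are illustrative only.
constexpr int k_for_split(int K, int KRead, int KBatch, int z /* blockIdx.z */)
{
    return (z < KBatch - 1) ? KRead : K - KRead * (KBatch - 1);
}

// For a row-major A the split's element offset is simply z * KRead
// (times M for a column-major A, mirroring the branches above).
constexpr int a_split_offset_rowmajor(int KRead, int z) { return z * KRead; }

static_assert(k_for_split(4000, 2048, 2, 0) == 2048, "first split");
static_assert(k_for_split(4000, 2048, 2, 1) == 1952, "last split takes the rest");
static_assert(k_for_split(4000, 2048, 2, 0) + k_for_split(4000, 2048, 2, 1) == 4000, "");
// ---------------------------------------------------------------------------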
karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return a_lds_block_desc_permuted; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(LDSTypeA), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by 
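// ---------------------------------------------------------------------------
// Editorial note (illustrative sketch, not part of this header):
// GetSharedMemoryNumberOfByte() above takes the maximum, not the sum, of the A
// staging tile and the C-shuffle tile because both reuse the same __shared__
// allocation and their lifetimes are separated by block_sync_lds. The element
// counts and element sizes below are placeholder examples.
#include <algorithm>
#include <cstddef>

constexpr std::size_t lds_bytes(std::size_t a_tile_elems, std::size_t sizeof_a,
                                std::size_t c_shuffle_elems, std::size_t sizeof_c)
{
    return std::max(a_tile_elems * sizeof_a, c_shuffle_elems * sizeof_c);
}

// e.g. a 256x128 bf16 A tile vs. a 64x128 fp32 shuffle tile
static_assert(lds_bytes(256 * 128, 2, 64 * 128, 4) == 65536, "A tile dominates here");
// ---------------------------------------------------------------------------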
{M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { +#if DEBUG_LOG + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + const AScaleType* p_a_scale_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = 
MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(1, math::integer_divide_ceil(problem.M, ScaleBlockM))); + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + 
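// ---------------------------------------------------------------------------
// Editorial note (sketch of the layout implied by the scale descriptors above,
// not part of this patch): the A scale tensor is addressed with the M-block
// index contiguous (strides (1, ceil(M/ScaleBlockM))) and the B scale tensor
// with the K-block index contiguous (strides (ceil(K/ScaleBlockK), 1)), one
// float per ScaleBlockM x ScaleBlockK (resp. ScaleBlockN x ScaleBlockK) tile.
// The helper names and sizes below are illustrative only.
#include <cstddef>

// offset of the scale covering element (m, k) of A
constexpr std::size_t a_scale_offset(std::size_t m, std::size_t k, std::size_t M,
                                     std::size_t ScaleBlockM, std::size_t ScaleBlockK)
{
    const std::size_t m_blocks = (M + ScaleBlockM - 1) / ScaleBlockM;
    return (k / ScaleBlockK) * m_blocks + m / ScaleBlockM;
}

// offset of the scale covering element (n, k) of B
constexpr std::size_t b_scale_offset(std::size_t n, std::size_t k, std::size_t K,
                                     std::size_t ScaleBlockN, std::size_t ScaleBlockK)
{
    const std::size_t k_blocks = (K + ScaleBlockK - 1) / ScaleBlockK;
    return (n / ScaleBlockN) * k_blocks + k / ScaleBlockK;
}

static_assert(a_scale_offset(130, 300, 1024, 128, 256) == 1 * 8 + 1, "m_blk=1, k_blk=1");
static_assert(b_scale_offset(130, 300, 4096, 128, 256) == 1 * 16 + 1, "n_blk=1, k_blk=1");
// ---------------------------------------------------------------------------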
decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + constexpr index_t ScaleSliceSizeM = MXdlPerWave; + constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN); + constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); + + // ScaleSliceSizeK is last dimension in A/B scale for vector memory access + // ScaleSliceSizeK is first dimension in C scale for packed math + constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + auto a_thread_offset = + get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, Number{})); + + auto a_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<1, 0>, + 0, + 1, + 1, + true>( + a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0)); + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + true>( + b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); + + // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1); + constexpr auto a_scale_thread_slice_copy_step = + make_tuple(make_multi_index(MWaves * MPerXdl, 0), + make_multi_index(-MPerBlock, 0), + make_multi_index(-MPerBlock, ScaleSliceSizeK)); + constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK); + + constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + + c_scale_thread_desc, + c_thread_buf, + + a_scale_grid_desc_am_ak, + a_scale_thread_desc, + a_scale_thread_copy, + a_scale_grid_buf, + a_scale_thread_slice_copy_step, + + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + 
b_scale_thread_slice_copy_step, + + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // transposed XDL + // // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + // // TODO: hacky, fix it! + // only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2)), // M2 = MPerXdl + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, 
+ 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), + tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space 
filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + const AScaleType* p_a_scale_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + void* p_shared1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(1, math::integer_divide_ceil(problem.M, ScaleBlockM))); + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // divide 
block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf_ping = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = 
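+            // GetLength(I0) * GetLength(I2) == AK0 * AK1 recovers the (split-K adjusted) K
+            // extent, so the quotient is the number of KPerBlock steps the main loop executes;
+            // readfirstlane keeps the loop count in an SGPR.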
__builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + constexpr index_t ScaleSliceSizeM = MXdlPerWave; + constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN); + constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); + + // ScaleSliceSizeK is last dimension in A/B scale for vector memory access + // ScaleSliceSizeK is first dimension in C scale for packed math + constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + auto a_thread_offset = + get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, Number{})); + + auto a_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<1, 0>, + 0, + 1, + 1, + true>( + a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset, 0)); + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + true>( + b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); + + // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1); + constexpr auto a_scale_thread_slice_copy_step = + make_tuple(make_multi_index(MWaves * MPerXdl, 0), + make_multi_index(-MPerBlock, 0), + make_multi_index(-MPerBlock, ScaleSliceSizeK)); + constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK); + + constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + + c_scale_thread_desc, + c_thread_buf, + + a_scale_grid_desc_am_ak, + a_scale_thread_desc, + a_scale_thread_copy, + a_scale_grid_buf, + a_scale_thread_slice_copy_step, + + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + b_scale_thread_slice_copy_step, + + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // transposed XDL + // // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + // // TODO: hacky, fix it! 
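+            // The 8-d C block view fetched below factors the output tile as
+            //   M -> (M0 = MXdlPerWave, M1 = MWave, M2 = MPerXdl)
+            //   N -> (N0 = NXdlPerWave, N1 = NWave, N2 * N3 * N4 = NPerXdl)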
+ // only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2)), // M2 = MPerXdl + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), + tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, 
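+                // (Ds are the NumDTensor auxiliary epilogue inputs, e.g. a broadcast bias row,
+                //  combined with the shuffled C tile by the CDE transfer below.)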
problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence> // ThreadTransferDstResetCoordinateAfterRunFlags + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(block_m_id, 0, block_n_id, 0)), + c_element_op}; + + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + 
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf)); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp new file mode 100755 index 000000000000..ca3902188e4a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -0,0 +1,2399 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" + +namespace ck { + +#ifndef KERNEL_GEMM_XDL_CSHUFFLE_V3_MX +#define KERNEL_GEMM_XDL_CSHUFFLE_V3_MX +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ enable_if_t +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) +{ +#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ enable_if_t +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) +{ +#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared_0, + p_shared_1, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} +#endif + +template +struct GridwiseGemmMX_xdl_cshuffle_v3 +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + static constexpr auto I8 = Number<8>{}; + static constexpr auto I9 = Number<9>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = false; + static constexpr auto is_scale_mfma = true; + + static constexpr auto MXdlPack = 2; + static constexpr auto NXdlPack = 2; + static constexpr auto KXdlPack = 2; + + //> KPack is at least the k_per_blk of selected mfma + // + // Should be a multiple of k_per_blk. 
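+    // i.e. KPack = max(lcm(AK1Number, BK1Number), selected_mfma.k_per_blk / APackedSize),
+    // counted in packed elements so that pk (packed) A/B data types contribute one unit per
+    // packed word.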
+ // TODO: Move this to blockwise pipeline base + // KPack in packed data types for pk A/B + + static constexpr index_t APackedSize = packed_size_v; + static constexpr index_t BPackedSize = packed_size_v; + + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk / + APackedSize); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + constexpr auto permuted_desc = transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_xor_with_modulo_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return transform_tensor_descriptor( + permuted_desc, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + 
make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + const auto a_grid_desc_permuted = transform_tensor_descriptor( + a_grid_desc_ak0_m_ak1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(MPad, AK0Number)), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto a_grid_desc = transform_tensor_descriptor( + a_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)), + make_pass_through_transform(MPad), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + return a_grid_desc; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + const auto a_grid_desc_permuted = transform_tensor_descriptor( + a_grid_desc_ak0_m_ak1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(M, AK0Number)), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto a_grid_desc = transform_tensor_descriptor( + a_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)), + make_pass_through_transform(M), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_grid_desc; + } + } 
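+
+    // Illustrative sketch only, kept out of the build: how the padded extents and AK0 produced
+    // by the Calculate* helpers above feed MakeAGridDescriptor_AK0_M_AK1. The helper name and
+    // problem sizes are assumed for the example and do not come from any caller in this file.
+#if 0
+    __host__ static void ExampleMakeADescriptor()
+    {
+        const index_t M = 512, K = 256, StrideA = 256, KBatch = 1; // assumed sample problem
+        const auto a_desc = MakeAGridDescriptor_AK0_M_AK1(M,
+                                                          CalculateMPadded(M),
+                                                          K,
+                                                          CalculateKPadded(K, KBatch),
+                                                          StrideA,
+                                                          CalculateAK0Padded(K, KBatch));
+        // When no padding is required the lengths come back as (AK0, M, AK1) =
+        // (K / AK1Value, M, AK1Value); the XOR-with-modulo transform only permutes how
+        // (M, AK0) coordinates map onto memory, it does not change the lengths.
+    }
+#endif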
+ + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + static_assert(!(is_same_v, f4x2_pk_t> && + (GemmSpec != GemmSpecialization::Default && + GemmSpec != GemmSpecialization::MPadding)), + "f4x2_pk_t does not support K padding"); + static_assert(!((is_same_v, f6x16_pk_t> || + is_same_v, bf6x16_pk_t> || + is_same_v, f6x32_pk_t> || + is_same_v, bf6x32_pk_t>)&&GemmSpec != + GemmSpecialization::Default), + "Packed F6 types do not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple( + make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + const auto b_grid_desc_permuted = transform_tensor_descriptor( + b_grid_desc_bk0_n_bk1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(N, BK0Number)), + 
make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto b_grid_desc = transform_tensor_descriptor( + b_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, BK0Number)), + make_pass_through_transform(N), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc; + } + else + { + // Weight Tile Permute + constexpr index_t BK01 = KPerBlock / BK1Value; + // const index_t BK00 = BK0 / BK01; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor( + ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor( + BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return 
transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideScaleA_, + index_t StrideB_, + index_t StrideScaleB_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideScaleA{StrideScaleA_}, + StrideB{StrideB_}, + StrideScaleB{StrideScaleB_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " + << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideScaleA; + index_t StrideB; + index_t StrideScaleB; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const AScaleDataType* p_a_scale_grid_, + const BDataType* p_b_grid_, + const BScaleDataType* p_b_scale_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideScaleA_, + index_t StrideB_, + index_t StrideScaleB_, + index_t StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_, + bool is_reduce_ = false) + : Problem{M_, + N_, + K_ / APackedSize, + StrideA_ / APackedSize, + StrideScaleA_, + StrideB_ / BPackedSize, + StrideScaleB_, + StrideC_, + k_batch_}, + p_a_grid{p_a_grid_}, + p_a_scale_grid{p_a_scale_grid_}, + p_b_grid{p_b_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + p_c_grid{p_c_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_}, + is_reduce(is_reduce_) + { + } + + __host__ __device__ inline bool IsReduceAdd() const + { + return (Problem::KBatch > 1) && is_reduce; + } + + __host__ __device__ inline bool IsAtomicAdd() const + { + return (Problem::KBatch > 1) && (!is_reduce); + } + + const ADataType* p_a_grid; + const AScaleDataType* p_a_scale_grid; + const BDataType* p_b_grid; + const BScaleDataType* p_b_scale_grid; + CDataType* p_c_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + bool is_reduce; + }; + + struct SplitKBatchOffset + { + + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) + { + if constexpr(is_same_v) + { + 
a_k_split_offset = k_id * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + if constexpr(!PermuteB) + { + b_k_split_offset = k_id * karg.KRead; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = k_id * k0_offset; + } + } + + // Calculate A scale offset + a_scale_k_split_offset = + k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack * MPerXdl; + + // Calculate B scale offset + b_scale_k_split_offset = + k_id * karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack * NPerXdl; + + if(k_id < (karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + + if(karg.IsReduceAdd()) + { + c_reduce_offset = k_id * karg.M * karg.N; + } + else + { + c_reduce_offset = 0; + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t a_scale_k_split_offset; // New member for scale matrix offset + index_t b_scale_k_split_offset; // New member for scale matrix offset + index_t c_reduce_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // contiguous in LDS + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return a_lds_block_desc_permuted; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto WaveSize = 64; + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = WaveSize / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // contiguous in lds + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto b_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return b_lds_block_desc_permuted; + } + else // RowMajor B + { + constexpr auto WaveSize = 64; + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = WaveSize / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? 
KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + 
constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0, + "KPerBlock should be multiple of ScaleBlockSize"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) + { + if(!karg.IsReduceAdd()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + if(karg.KBatch > 1) + { + return false; + } + } + } +#if 0 + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } +#endif + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, 
n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + using mx_scale_t = e8m0_bexp_t; + static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + + static_assert(is_same_v && + is_same_v, + "A/B ElementwiseOperation should be PassThrough as load_to_lds is used!"); + + template + __device__ static void Run(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // A Scale buffer + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + + // B Scale buffer + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 2, + 
ABlockTransferSrcScalarPerVector>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // Initial thread mapping for: + // BlockSize = 256 + // MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2 + // For each [m0, n0] tile, there are 4 waves: + // tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0] + // tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1] + // tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0] + // tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1] + + // BlockSize = 128 + // MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1 + // For each [m0, n0] tile, there are 2 waves: + // tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0] + // tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0] + + // TODO: Document initial thread mapping for more combinations of parameters + + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + // static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; + + // auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / + // mfma.selected_mfma.num_threads_per_blk; + + // A wave access continuous memory + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; + + auto a_thread_offset_m = waveId_m; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + 
true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); + + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && + CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + // TODO: hacky, fix it! 
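+            // Orientation note (illustrative arithmetic, using the example configuration
+            // from the thread-mapping comment above): with MPerBlock = 128, MXdlPerWave = 2
+            // and MPerXdl = 32, MWave above evaluates to 128 / (2 * 32) = 2, and NWave = 2
+            // for the matching N-side parameters. The epilogue below then writes C out in
+            // two hops per shuffle step:
+            //   accumulator VGPRs -> LDS staging tile  (ThreadwiseTensorSliceTransfer_v1r3)
+            //   LDS staging tile  -> global C tile     (ThreadGroupTensorSliceTransfer_v6r1)
+            // stepping over the remaining M/N repeats along matching space filling curves,
+            // with block_sync_lds() fencing each hop.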
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 = MXdlPack + M3, // M3 * M4 * M5 = MPerXdl + M4, + M5)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + 
InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, 
problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // A/B shuffled scale for better 8-bit scale access pattern + // MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack + const auto Padded_Scale_M = + math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize; + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / APackedSize)) * + MPerXdl * MXdlPack / scale_pack_size_a, + 64 * KXdlPack * MXdlPack / scale_pack_size_a, + 1)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / BPackedSize)) * + NPerXdl * NXdlPack / scale_pack_size_b, + 64 * KXdlPack * NXdlPack / scale_pack_size_b, + 1)); + + Run(p_a_grid, + p_a_scale_grid, + p_b_grid, + p_b_scale_grid, + p_c_grid, + p_shared, + problem, + a_grid_desc_ak0_m_ak1, + a_scale_grid_desc_am_ak, + b_grid_desc_bk0_n_bk1, + b_scale_grid_desc_bn_ak, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // A Scale buffer + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + + // B Scale buffer + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + 
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + bit_cast(static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + bit_cast(bit_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // Initial thread mapping for: + // BlockSize = 256 + // MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 
MWaves=NWaves=2 + // For each [m0, n0] tile, there are 4 waves: + // tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0] + // tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1] + // tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0] + // tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1] + + // BlockSize = 128 + // MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1 + // For each [m0, n0] tile, there are 2 waves: + // tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0] + // tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0] + + // TODO: Document initial thread mapping for more combinations of parameters + + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + // static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; + + // auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / + // mfma.selected_mfma.num_threads_per_blk; + + // A wave access continuous memory + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; + + auto a_thread_offset_m = waveId_m; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); + + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && + CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! 
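+            // Note on the 2-LDS variant (a sketch of the intended overlap, not a guarantee
+            // made by this code): Run_2Lds stages A/B tiles in two separate __shared__
+            // allocations,
+            //   even main-loop iteration: compute reads the "ping" buffers (p_shared_0)
+            //                             while the next tile is prefetched into "pong",
+            //   odd  main-loop iteration: the roles of ping and pong are swapped,
+            // so ds_read and ds_write of consecutive iterations never touch the same LDS
+            // chunk. After the main loop, the C-shuffle staging tile below reuses
+            // p_shared_0.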
+ constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 = MXdlPack + M3, // M3 * M4 * M5 = MPerXdl + M4, + M5)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + 
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* 
p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // A/B shuffled scale for better 8-bit scale access pattern + // MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack + const auto Padded_Scale_M = + math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize; + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / APackedSize)) * + MPerXdl * MXdlPack / scale_pack_size_a, + 64 * KXdlPack * MXdlPack / scale_pack_size_a, + 1)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / BPackedSize)) * + NPerXdl * NXdlPack / scale_pack_size_b, + 64 * KXdlPack * NXdlPack / scale_pack_size_b, + 1)); + + Run_2Lds(p_a_grid, + p_a_scale_grid, + p_b_grid, + p_b_scale_grid, + p_c_grid, + p_shared_0, + p_shared_1, + problem, + a_grid_desc_ak0_m_ak1, + a_scale_grid_desc_am_ak, + b_grid_desc_bk0_n_bk1, + b_scale_grid_desc_bn_ak, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp new file mode 100755 index 000000000000..6691c634842f --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp @@ -0,0 +1,2345 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
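+// This header provides the device-side pieces of the MX (microscaled) GEMM with a
+// preshuffled B operand:
+//   * two kernel_gemm_xdl_cshuffle_v3_mx entry points - a single-LDS variant and a
+//     double-LDS (ping/pong) variant that dispatches GridwiseGemm::Run_2Lds,
+//   * GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle - grid/LDS descriptor construction,
+//     Problem/Argument/SplitKBatchOffset bookkeeping and the block-level GEMM body.
+// Both kernels are guarded for gfx950 device compilation and otherwise only consume
+// their arguments.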
+ +#pragma once + +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_bpreshuffle_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" + +namespace ck { + +#ifndef KERNEL_GEMM_XDL_CSHUFFLE_V3_MX +#define KERNEL_GEMM_XDL_CSHUFFLE_V3_MX +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe +template +__global__ enable_if_t +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) +{ +#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ enable_if_t +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) +{ +#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared_0, + p_shared_1, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} +#endif + +template +struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = 
Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + static constexpr auto I8 = Number<8>{}; + static constexpr auto I9 = Number<9>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = false; + static constexpr auto is_scale_mfma = true; + + static constexpr auto MXdlPack = 2; + static constexpr auto NXdlPack = 2; + static constexpr auto KXdlPack = 2; + + //> KPack is at least the k_per_blk of selected mfma + // + // Should be a multiple of k_per_blk. + // TODO: Move this to blockwise pipeline base + // KPack in packed data types for pk A/B + + static constexpr index_t APackedSize = packed_size_v; + static constexpr index_t BPackedSize = packed_size_v; + + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk / + APackedSize); + + static constexpr index_t NLane = NPerXdl; + static constexpr index_t KLane = 64 / NLane; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + static constexpr index_t KRepeat = KPerBlock / KLane / KPack; + + using ThisThreadBlock = ThisThreadBlock; + + using mx_scale_t = e8m0_bexp_t; + static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ __device__ static auto CalculateBN0Shuffled(index_t N) + { + return math::integer_divide_ceil(N, NLane); + } + __host__ __device__ static auto CalculateBK0Shuffled(index_t K) + { + return math::integer_divide_ceil(K, KLane * KPack); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return 
math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + if constexpr(IsXor) + { + constexpr auto permuted_desc = transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_xor_with_modulo_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return transform_tensor_descriptor( + permuted_desc, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{})); + } + else + { + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{})); + } + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + 
make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + const auto a_grid_desc_permuted = transform_tensor_descriptor( + a_grid_desc_ak0_m_ak1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(M, AK0Number)), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto a_grid_desc = transform_tensor_descriptor( + a_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)), + make_pass_through_transform(M), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_grid_desc; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) + { + constexpr index_t NkSwizzleNumber = Number{}; + return make_naive_tensor_descriptor_packed( + make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber)); + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + static_assert(!(is_same_v, f4x2_pk_t> && + GemmSpec != GemmSpecialization::Default), + "f4x2_pk_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + 
make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple( + make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + const auto b_grid_desc_permuted = transform_tensor_descriptor( + b_grid_desc_bk0_n_bk1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(N, BK0Number)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto b_grid_desc = transform_tensor_descriptor( + b_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, BK0Number)), + make_pass_through_transform(N), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc; + } + else + { + // Weight Tile Permute + constexpr index_t BK01 = KPerBlock / BK1Value; + // const index_t BK00 = BK0 / BK01; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor( + ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor( + BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return 
make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideScaleA_, + index_t StrideB_, + index_t StrideScaleB_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideScaleA{StrideScaleA_}, + StrideB{StrideB_}, + StrideScaleB{StrideScaleB_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)}, + BN0Shuffled{CalculateBN0Shuffled(N_)}, + BK0Shuffled{CalculateBK0Shuffled(K_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " + << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideScaleA; + index_t StrideB; + index_t StrideScaleB; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + // FOR PRESHUFFLE ONLY + index_t BN0Shuffled; + index_t BK0Shuffled; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public 
Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const AScaleDataType* p_a_scale_grid_, + const BDataType* p_b_grid_, + const BScaleDataType* p_b_scale_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideScaleA_, + index_t StrideB_, + index_t StrideScaleB_, + index_t StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_, + bool is_reduce_ = false) + : Problem{M_, + N_, + K_ / APackedSize, + StrideA_ / APackedSize, + StrideScaleA_, + StrideB_ / BPackedSize, + StrideScaleB_, + StrideC_, + k_batch_}, + p_a_grid{p_a_grid_}, + p_a_scale_grid{p_a_scale_grid_}, + p_b_grid{p_b_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + p_c_grid{p_c_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_}, + is_reduce(is_reduce_) + { + } + + __host__ __device__ inline bool IsReduceAdd() const + { + return (Problem::KBatch > 1) && is_reduce; + } + + __host__ __device__ inline bool IsAtomicAdd() const + { + return (Problem::KBatch > 1) && (!is_reduce); + } + + const ADataType* p_a_grid; + const AScaleDataType* p_a_scale_grid; + const BDataType* p_b_grid; + const BScaleDataType* p_b_scale_grid; + CDataType* p_c_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + bool is_reduce; + }; + + struct SplitKBatchOffset + { + + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) + { + if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + if constexpr(!PermuteB) + { + b_k_split_offset = k_id * karg.KRead * NPerXdl; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = k_id * k0_offset; + } + } + + // Calculate A scale offset + a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack * + MPerXdl / scale_pack_size_a; + + // Calculate B scale offset + b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack * + NPerXdl / scale_pack_size_b; + + if(k_id < (karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + + if(karg.IsReduceAdd()) + { + c_reduce_offset = k_id * karg.M * karg.N; + } + else + { + c_reduce_offset = 0; + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t a_scale_k_split_offset; // New member for scale matrix offset + index_t b_scale_k_split_offset; // New member for scale matrix offset + index_t c_reduce_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // contiguous in LDS + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. 
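+        // Illustrative sketch (editorial; the bank count, bank width and exact pad value are
+        // assumptions, not taken from this file): the branch above keeps K contiguous per M row,
+        // so the M stride acts as a row pitch in LDS, and a small pad in that pitch staggers
+        // successive rows across banks. That avoids the conflicts the XOR swizzles below solve
+        // at the cost of extra VGPRs. Assuming 32 banks of 4 bytes each:
+        //
+        //   row_pitch_bytes = (KPerBlock + pad) * sizeof(ADataType);
+        //   bank_of_row(r)  = (r * row_pitch_bytes / 4) % 32;
+        //
+        // With pad == 0 and KPerBlock a power of two, bank_of_row repeats after very few rows,
+        // serializing column reads; a non-zero pad breaks that periodicity.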
+ else if constexpr(is_same::value) + { + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return a_lds_block_desc_permuted; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto WaveSize = 64; + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = WaveSize / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr 
auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // K0 -> N0/NWave/NXdlPack -> NWave -> NXdlPack -> KLane -> NLane -> KPack + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + I1, + Number{}, + Number{}, + Number{})); + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + // constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(ADataType), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0, + "KPerBlock should be multiple of ScaleBlockSize"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! 
N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! 
" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) + { + if(!karg.IsReduceAdd()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + if(karg.KBatch > 1) + { + return false; + } + } + } + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // A Scale buffer + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + + // B Scale buffer + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + 
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // Initial thread mapping for: + // BlockSize = 256 + // MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2 + // For each [m0, n0] tile, there are 4 waves: + // tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0] + // tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1] + // tId in [128, 191] m x n 
= [32, 63] x [ 0, 31] waveId = [1, 0] + // tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1] + + // BlockSize = 128 + // MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1 + // For each [m0, n0] tile, there are 2 waves: + // tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0] + // tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0] + + // TODO: Document initial thread mapping for more combinations of parameters + + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + // static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; + + // auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / + // mfma.selected_mfma.num_threads_per_blk; + + // A wave access continuous memory + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; + + auto a_thread_offset_m = waveId_m; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); + + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && + CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + // TODO: hacky, fix it! 
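+            // Orientation sketch (editorial, simplified; the names below paraphrase the objects
+            // constructed further down rather than quoting them): the write-out drains the
+            // accumulators in num_access passes along two matching space-filling curves,
+            // staging one CShuffle tile through LDS per pass so the global store stays wide:
+            //
+            //   for (access = 0; access < num_access; ++access) {
+            //       block_sync_lds();                      // LDS free to be overwritten
+            //       vgpr_to_lds.Run(curve_vgpr[access]);   // accumulators -> tile in LDS
+            //       block_sync_lds();                      // tile fully written
+            //       lds_to_global.Run(...);                // tile -> C in global memory
+            //       if (access + 1 < num_access)
+            //           lds_to_global.MoveDstSliceWindow(curve_global.step(access));
+            //   }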
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 = MXdlPack + M3, // M3 * M4 * M5 = MPerXdl + M4, + M5)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + 
InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, 
problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // A/B shuffled scale for better 8-bit scale access pattern + // MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.M / (MXdlPack * MPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b)); + + Run(p_a_grid, + p_a_scale_grid, + p_b_grid, + p_b_scale_grid, + p_c_grid, + p_shared, + problem, + a_grid_desc_ak0_m_ak1, + a_scale_grid_desc_am_ak, + b_grid_desc_bk0_n_bk1, + b_scale_grid_desc_bn_ak, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = + make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // A Scale buffer + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + + // B Scale buffer + const auto b_scale_grid_buf = + make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + 
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack); + + // lds max alignment + // constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); + + // dummys + auto b_block_buf_ping = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), // actually the thread desc + Sequence{}, + I1, + Number{}, + Number{}, + Number{}>, + Sequence<0, 1, 2, 3, 4>, + 4, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bk0_n_bk1, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + 0, + KPack * (get_thread_local_1d_id() % WarpSize))); + + // LDS allocation for A and B: be careful of alignment + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // Initial thread mapping for: + // BlockSize = 256 + // MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2 + // For each [m0, n0] tile, there are 4 waves: + // tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0] + // tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1] + // tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0] + // tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1] + + // BlockSize = 128 + // MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1 + // For each [m0, n0] tile, there are 2 waves: + // tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0] + // tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0] + + // TODO: Document initial thread mapping for more combinations of parameters + + const auto wave_idx = 
BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + // static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; + + // auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / + // mfma.selected_mfma.num_threads_per_blk; + + // A wave access continuous memory + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; + + auto a_thread_offset_m = waveId_m; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); + + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && + CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + // constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + // TODO: hacky, fix it! 
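+            // Editorial note on the split-K variant of the write-out below (the DWORD size is
+            // taken from the lambda further down; the stated reason is an assumption): when
+            // CGlobalMemoryDataOperation is AtomicAdd, the global store vector is narrowed to
+            // one DWORD worth of elements and the thread cluster is rebalanced so each pass
+            // still covers the same tile, roughly:
+            //
+            //   atomic_vec          = 4 / sizeof(CDataType);            // e.g. 2 for bf16
+            //   cluster[NPerBlock] *= ScalarPerVector / atomic_vec;     // more lanes along N
+            //   cluster[MPerBlock] /= ScalarPerVector / atomic_vec;     // fewer along M
+            //
+            // presumably because the buffer atomic-add path operates on 32-bit quantities.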
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 = MXdlPack + M3, // M3 * M4 * M5 = MPerXdl + M4, + M5)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + 
InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // calculate C grid descriptor + constexpr auto DWORD_BYTES = 4; + constexpr auto atomic_vector_size = DWORD_BYTES / sizeof(CDataType); + + constexpr auto CShuffleBlockTransferClusterLengths = [&]() { + if constexpr(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set) + { + return CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{}; + } + // Atomic operation + else + { + return generate_sequence_v2( + [&](auto i) { + if constexpr(i == 3) + { + return Number< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{} + .At(i) * + CShuffleBlockTransferScalarPerVector_NPerBlock / + atomic_vector_size>{}; + } + else if constexpr(i == 1) + { + return Number< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{} + .At(i) / + CShuffleBlockTransferScalarPerVector_NPerBlock * + atomic_vector_size>{}; + } + else + { + return Number< + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock{} + .At(i)>{}; + } + }, + Number<4>{}); + } + }(); + + constexpr auto CShuffleBlockTransferScalarPerVector = [&]() { + if constexpr(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::Set) + { + return CShuffleBlockTransferScalarPerVector_NPerBlock; + } + else + { + return atomic_vector_size; + } + }(); + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + decltype(CShuffleBlockTransferClusterLengths), + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure 
it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // A/B shuffled scale for better 8-bit scale access pattern + // MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack + // We pad the M unconditionaly for Scale + const auto Padded_Scale_M = + math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize; + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / APackedSize)) * + MPerXdl * MXdlPack / scale_pack_size_a, + 64 * KXdlPack * MXdlPack / scale_pack_size_a, + 1)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / BPackedSize)) * + NPerXdl * NXdlPack / scale_pack_size_b, + 64 * KXdlPack * NXdlPack / scale_pack_size_b, + 1)); + + Run_2Lds(p_a_grid, + p_a_scale_grid, + p_b_grid, + p_b_scale_grid, + p_c_grid, + p_shared_0, + p_shared_1, + problem, + a_grid_desc_ak0_m_ak1, + a_scale_grid_desc_am_ak, + b_grid_desc_bk0_n_bk1, + b_scale_grid_desc_bn_ak, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp new file mode 100755 index 000000000000..67fb4d651e9d --- 
/dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_layernorm_cshuffle_v1.hpp @@ -0,0 +1,1083 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" + +namespace ck { + +// D = Layernorm(acc_element_op(A * B + broadcast(bias)) + add) * broadcast(gamma) + broadcast(beta) +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_layernorm_xdl_cshuffle_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, // MxN + const FloatC0* __restrict__ p_c0_bias_grid, // 1xN + const FloatC0* __restrict__ p_c0_add_grid, // MxN + const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN + const FloatC0* __restrict__ p_c0_beta_grid, // 1xN + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const AccElementwiseOperation acc_element_op, + const CElementwiseOperation c_element_op, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_NBlock_NPerBlock c0_grid_desc_nblock_nperblock, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + // TODO ANT: separate into MMA + Epilogue + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_c0_bias_grid, + p_c0_add_grid, + p_c0_gamma_grid, + p_c0_beta_grid, + p_shared, + a_element_op, + b_element_op, + acc_element_op, + c_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_grid_desc_nblock_nperblock, + block_2_ctile_map); + + // TODO ANT: Run layernorm epilogue here +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_c0_bias_grid; + ignore = p_c0_add_grid; + ignore = p_c0_gamma_grid; + ignore = p_c0_beta_grid; + ignore = a_element_op; + ignore = b_element_op; + ignore = acc_element_op; + ignore = c_element_op; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = c0_grid_desc_nblock_nperblock; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +// The GEMM + Layernorm implementation is a 
specialized kernel which allows fusing both layers +// together given the condition GEMM extents N of MNK is spanned by a single workgroup. For example, +// a kernel configured with NPerBlock = 128 allows to operate on all GEMM sizes if N <= 128 +template +struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + // Align 16 bytes (maximum LDS read/write width) + constexpr auto c_block_size_aligned = + math::integer_least_multiple( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() * + sizeof(FloatCShuffle), + 16) / + sizeof(FloatCShuffle); + + // LDS allocation for reduction workspace + constexpr index_t c_lds_workspace_size = BlockSize; + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size_aligned * sizeof(FloatCShuffle) + + c_lds_workspace_size * sizeof(FloatReduceAcc)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static 
constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // in order to reduce N dim without elaborate sync across CUs in single kernel, one + // workgroup must span the entire N extent + if(math::integer_divide_ceil(N, NPerBlock) > 1) + { + return false; + } + + // static check: all waves in the workgroups combined must cover whole N extent in order + // to have efficient N-dim reduction + static_assert(CShuffleNXdlPerWavePerShuffle == NXdlPerWave, + "condition not met for efficient layernorm"); + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // for bias, beta, gamma + __host__ __device__ static constexpr auto + MakeC0GridDescriptor_NBlock_NPerBlock(const C0GridDesc_N& c0_grid_desc_n) + { + const auto N = c0_grid_desc_n.GetLength(I0); + const auto NBlock = N / NPerBlock; + + const auto c0_grid_desc_nblock_nperblock = transform_tensor_descriptor( + c0_grid_desc_n, + make_tuple(make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1>{})); + + return c0_grid_desc_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + using C0GridDescriptor_NBlock_NPerBlock = + remove_cvref_t; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const 
FloatC0* __restrict__ p_c0_bias_grid, // 1xN + const FloatC0* __restrict__ p_c0_add_grid, // MxN + const FloatC0* __restrict__ p_c0_gamma_grid, // 1xN + const FloatC0* __restrict__ p_c0_beta_grid, // 1xN + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const AccElementwiseOperation& acc_element_op, + const CElementwiseOperation& c_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const C0GridDescriptor_NBlock_NPerBlock& c0_grid_desc_nblock_nperblock, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c0_bias_grid_buf = make_dynamic_buffer( + p_c0_bias_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize()); + // Note: c0_add is of same layout as c so we don't declare new c0_add_desc here + auto c0_add_grid_buf = make_dynamic_buffer( + p_c0_add_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + auto c0_gamma_grid_buf = make_dynamic_buffer( + p_c0_gamma_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize()); + auto c0_beta_grid_buf = make_dynamic_buffer( + p_c0_beta_grid, c0_grid_desc_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + 
BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max( + lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatAB, + FloatAB, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = + GridwiseGemmPipeline_Selector(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
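// --- Illustrative sketch (not part of the patch): scalar model of what the pipelined
// main loop above computes per workgroup. The kernel's comment defines the block-level
// product as c_mtx += transpose(a_mtx) * b_mtx, with the A tile stored K-major in LDS;
// GridwiseGemmPipeline_Selector interleaves the global->LDS blockwise copies with the
// xdlops blockwise GEMM over num_k_block_main_loop = K / KPerBlock iterations. Stripped
// of the prefetch/double-buffering machinery it reduces to the loop below. All names
// here (run_main_loop_reference, etc.) are hypothetical; plain row-major arrays stand
// in for the LDS tiles and the VGPR accumulators (c_thread_buf).
#include <cstddef>
#include <vector>

inline void run_main_loop_reference(const std::vector<float>& A, // M x K, row-major
                                    const std::vector<float>& B, // K x N, row-major
                                    std::vector<float>& C,       // M x N, row-major
                                    std::size_t M, std::size_t N, std::size_t K,
                                    std::size_t KPerBlock)
{
    const std::size_t num_k_block_main_loop = K / KPerBlock; // assumes K % KPerBlock == 0
    for(std::size_t kb = 0; kb < num_k_block_main_loop; ++kb)
    {
        // "a_blockwise_copy / b_blockwise_copy": in the kernel this is a cooperative
        // global->LDS copy of one KPerBlock-deep tile; here it is just an index offset.
        const std::size_t k0 = kb * KPerBlock;

        // "blockwise_gemm": accumulate this k-slice's partial product into the
        // per-thread accumulators.
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(std::size_t k = 0; k < KPerBlock; ++k)
                    acc += A[m * K + (k0 + k)] * B[(k0 + k) * N + n];
                C[m * N + n] += acc;
            }
    }
}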
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = 
ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + c_element_op}; + + const auto NBlock = c0_grid_desc_nblock_nperblock.GetLength(I0); + + // for broadcasting bias, beta, gamma + const auto c0_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c0_grid_desc_nblock_nperblock, + make_tuple(make_insert_transform(I1), + make_insert_transform(I1), + make_pass_through_transform(NBlock), + make_pass_through_transform(NPerBlock)), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + // LDS c_reduce_block_desc_mperblock_nperblock + constexpr auto c_reduce_block_desc_mperblock_nperblock = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I1)), + make_freeze_transform(I0), + make_pass_through_transform( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetLength(I3))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<>{}, Sequence<1>{})); + + static_assert(CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) * + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + BlockSize, + "wrong!"); + + static_assert((CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0) == + 0 && + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) % + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1) == + 0, + "wrong!"); + + constexpr index_t mreduce_per_thread = + (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I0); + + constexpr index_t nreduce_per_thread = + (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) / + CReduceThreadClusterLengths_MPerBlock_NPerBlock::At(I1); + + constexpr auto c_reduce_thread_lengths_mperblock_nperblock = + Sequence{}; + + // pytorch default + // https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html + static constexpr FloatReduceAcc epsilon = 1e-5; + + // VGPR c_reduce_thread_desc_mperblock_nperblock + constexpr auto c_reduce_thread_desc_mperblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr auto c_reduce_thread_desc_mblock_mperblock_nblock_nperblock = + 
make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, Number{})); + + // VGPR d_reduce_thread_desc_mperblock + constexpr auto d_reduce_thread_desc_mperblock = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + // TODO: this should be implemented as a blockwise reduction + auto c_reduce_thread_buf = make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + auto c0_thread_buf = make_static_buffer( + c_reduce_thread_desc_mperblock_nperblock.GetElementSpaceSize()); + + // Align 16 bytes (maximum LDS read/write width) + constexpr auto c_block_size_aligned = + math::integer_least_multiple( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize() * + sizeof(FloatCShuffle), + 16) / + sizeof(FloatCShuffle); + + auto d_reduce_work_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + + c_block_size_aligned), + BlockSize); + + // Sum thread workspace + auto d0_thread_buf = make_static_buffer( + d_reduce_thread_desc_mperblock.GetElementSpaceSize()); + + // Squared sum thread workspace + auto d1_thread_buf = make_static_buffer( + d_reduce_thread_desc_mperblock.GetElementSpaceSize()); + + // reduce: threadwise copy from LDS to VGPR + constexpr auto c_reduce_thread_cluster_desc = make_cluster_descriptor( + CReduceThreadClusterLengths_MPerBlock_NPerBlock{}, Sequence<1, 0>{}); + + const auto c_reduce_thread_cluster_idx = + c_reduce_thread_cluster_desc.CalculateBottomIndex( + make_multi_index(get_thread_local_1d_id())); + + const auto c_reduce_thread_data_idx_begin = + c_reduce_thread_cluster_idx * c_reduce_thread_lengths_mperblock_nperblock; + + auto c_reduce_thread_copy_lds_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatCShuffle, + FloatReduceAcc, + decltype(c_reduce_block_desc_mperblock_nperblock), + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, c_reduce_thread_data_idx_begin}; + + auto c_reduce_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + FloatReduceAcc, + FloatCShuffle, + decltype(c_reduce_thread_desc_mperblock_nperblock), + decltype(c_reduce_block_desc_mperblock_nperblock), + tensor_operation::element_wise::PassThrough, + decltype(c_reduce_thread_lengths_mperblock_nperblock), + Sequence<0, 1>, + 1, + CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + InMemoryDataOperationEnum::Set, + 1, + true>{c_reduce_block_desc_mperblock_nperblock, + c_reduce_thread_data_idx_begin, + tensor_operation::element_wise::PassThrough{}}; + + auto c0_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatC0, + FloatC0, + decltype(c0_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + 1, + true>(c0_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], + c_reduce_thread_data_idx_begin[I0], + block_work_idx[I1], + c_reduce_thread_data_idx_begin[I1])); + + // Note: c0_add is of same layout as c so we don't declare new c0_add_desc here + auto c0_add_thread_copy_global_to_vgpr = ThreadwiseTensorSliceTransfer_v2< + FloatC0, + FloatC0, + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(c_reduce_thread_desc_mblock_mperblock_nblock_nperblock), + Sequence, + Sequence<0, 1, 2, 3>, + 3, + 
CReduceThreadCopySrcDstScalarPerVector_NPerBlock, + 1, + true>(c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], + c_reduce_thread_data_idx_begin[I0], + block_work_idx[I1], + c_reduce_thread_data_idx_begin[I1])); + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + block_sync_lds(); + + // load from LDS and global, add bias + c_reduce_thread_copy_lds_to_vgpr.Run(c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf, + c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf); + + c0_thread_copy_global_to_vgpr.Run( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c0_bias_grid_buf, + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c0_thread_buf); + + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + FloatReduceAcc out; + acc_element_op(out, + c_reduce_thread_buf(i) + + static_cast(c0_thread_buf(i))); + c_reduce_thread_buf(i) = out; // acc_element_op(acc + bias) + }); + + c0_add_thread_copy_global_to_vgpr.Run( + c_grid_desc_mblock_mperblock_nblock_nperblock, + c0_add_grid_buf, + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c0_thread_buf); + + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + c_reduce_thread_buf(i) += + static_cast(c0_thread_buf(i)); // add + }); + + // layernorm + { + using ThreadwiseReduceD0 = + ThreadwiseReduction; + using ThreadwiseReduceD1 = + ThreadwiseReduction; + + const auto d0_zeroVal = + ThreadwiseReduceD0::Op::template GetIdentityValue(); + const auto d1_zeroVal = + ThreadwiseReduceD1::Op::template GetIdentityValue(); + static_for<0, mreduce_per_thread, 1>{}( + [&](auto i) { d0_thread_buf(i) = d0_zeroVal; }); + static_for<0, mreduce_per_thread, 1>{}( + [&](auto i) { d1_thread_buf(i) = d1_zeroVal; }); + + // reduce sum in VGPR + ThreadwiseReduceD0::Reduce(c_reduce_thread_buf, d0_thread_buf); + + // reduce squared sum in VGPR + ThreadwiseReduceD1::Reduce(c_reduce_thread_buf, d1_thread_buf); + + // reduce within workgroup + using BlockwiseReduce = PartitionedBlockwiseReduction< + FloatReduceAcc, + BlockSize, + CReduceThreadClusterLengths_MPerBlock_NPerBlock, // ThreadClusterLengths_M_K + Sequence<1, 0>, // ThreadClusterArrangeOrder + reduce::Add, + false>; + + static_for<0, mreduce_per_thread, 1>{}([&](auto i) { + block_sync_lds(); + BlockwiseReduce::Reduce(d_reduce_work_buf, + d0_thread_buf(i)); // blockwise reduced sum + block_sync_lds(); + BlockwiseReduce::Reduce(d_reduce_work_buf, + d1_thread_buf(i)); // blockwise reduced squared sum + }); + + // 
normalize + const index_t NRaw = + c_grid_desc_mblock_mperblock_nblock_nperblock.GetTransforms()[I0] + .GetUpperLengths()[I1]; // TODO: proper handle + + static_for<0, mreduce_per_thread, 1>{}([&](auto im) { + static_for<0, nreduce_per_thread, 1>{}([&](auto in) { + constexpr auto dst_offset = + Number{}; + + constexpr auto src_offset = + Number{}; + + FloatReduceAcc avg_sum = d0_thread_buf(src_offset) / NRaw; + FloatReduceAcc avg_squared_sum = d1_thread_buf(src_offset) / NRaw; + + FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum; + FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum; + FloatReduceAcc divisor_sqrt; + tensor_operation::element_wise::UnarySqrt{}(divisor_sqrt, divisor); + + c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt; + }); + }); + + // scaling + c0_thread_copy_global_to_vgpr.Run( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c0_gamma_grid_buf, + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c0_thread_buf); + + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + c_reduce_thread_buf(i) *= + static_cast(c0_thread_buf(i)); // * gamma + }); + + c0_thread_copy_global_to_vgpr.Run( + c0_grid_desc_mblock_mperblock_nblock_nperblock, + c0_beta_grid_buf, + c_reduce_thread_desc_mblock_mperblock_nblock_nperblock, + make_tuple(I0, I0, I0, I0), + c0_thread_buf); + + static_for<0, c_reduce_thread_desc_mperblock_nperblock.GetElementSize(), 1>{}( + [&](auto i) { + c_reduce_thread_buf(i) += + static_cast(c0_thread_buf(i)); // + beta + }); + + block_sync_lds(); + + c_reduce_thread_copy_vgpr_to_lds.Run(c_reduce_thread_desc_mperblock_nperblock, + make_tuple(I0, I0), + c_reduce_thread_buf, + c_reduce_block_desc_mperblock_nperblock, + c_shuffle_block_buf); + + } // end layernorm + + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + + // move on C0 + c0_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + c0_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + + // move on C0_add + c0_add_thread_copy_global_to_vgpr.MoveSrcSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp new file mode 100755 index 000000000000..50363d832ec5 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_waveletmodel_cshuffle.hpp @@ -0,0 +1,761 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
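// --- Illustrative sketch (not part of the patch): scalar reference for the layernorm
// epilogue implemented by the preceding GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// kernel, i.e. D = Layernorm(acc_element_op(A*B + broadcast(bias)) + add) * gamma + beta,
// reduced over the N dimension of each row (N fits in a single workgroup). The kernel
// derives the variance from the reduced sum and squared sum (E[x^2] - E[x]^2) and uses
// the PyTorch default epsilon of 1e-5. Function and buffer names below are hypothetical;
// acc_element_op is taken to be the identity for brevity.
#include <cmath>
#include <cstddef>
#include <vector>

inline void gemm_layernorm_epilogue_reference(std::vector<float>& acc,         // M x N GEMM accumulators, row-major
                                              const std::vector<float>& bias,  // 1 x N
                                              const std::vector<float>& add,   // M x N
                                              const std::vector<float>& gamma, // 1 x N
                                              const std::vector<float>& beta,  // 1 x N
                                              std::size_t M, std::size_t N)
{
    constexpr float epsilon = 1e-5f; // matches the kernel's FloatReduceAcc epsilon
    for(std::size_t m = 0; m < M; ++m)
    {
        float sum = 0.f, sq_sum = 0.f;
        for(std::size_t n = 0; n < N; ++n)
        {
            float x = acc[m * N + n] + bias[n]; // acc_element_op(acc + bias), identity op here
            x += add[m * N + n];                // + c0_add
            acc[m * N + n] = x;
            sum += x;        // d0: reduced sum
            sq_sum += x * x; // d1: reduced squared sum
        }
        const float mean    = sum / static_cast<float>(N);
        const float var     = sq_sum / static_cast<float>(N) - mean * mean; // E[x^2] - E[x]^2
        const float inv_std = 1.f / std::sqrt(var + epsilon);
        for(std::size_t n = 0; n < N; ++n)
            acc[m * N + n] = (acc[m * N + n] - mean) * inv_std * gamma[n] + beta[n];
    }
}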
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_waveletmodel.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + static constexpr auto AK0PerBlock = Number{}; + static constexpr auto BK0PerBlock = Number{}; + + struct TileLoadThreadGroup + { + __device__ static constexpr index_t GetNumOfThread() { return TileLoadThreadGroupSize; } + + __device__ static constexpr bool IsBelong() + { + return (get_thread_local_1d_id() >= TileLoadThreadGroupSize); + } + + __device__ static index_t GetThreadId() + { + return get_thread_local_1d_id() - TileMathThreadGroupSize; + } + }; + + struct TileMathThreadGroup + { + __device__ static constexpr index_t GetNumOfThread() { return TileMathThreadGroupSize; } + + __device__ static constexpr bool IsBelong() + { + return get_thread_local_1d_id() < TileMathThreadGroupSize; + } + + __device__ static index_t GetThreadId() { return get_thread_local_1d_id(); } + }; + + using CShuffleBlockTransferThreadGroup = ThisThreadBlock; + + // load and math+store Wave pipelines. 
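// --- Illustrative sketch (not part of the patch): how the two thread groups above
// partition one workgroup in the wavelet model. Assuming the block is launched with
// TileMathThreadGroupSize + TileLoadThreadGroupSize threads, threads [0, Math) run the
// math+store pipeline and threads [Math, Math+Load) run the load pipeline; the load
// group's local id is rebased by subtracting TileMathThreadGroupSize, mirroring
// TileLoadThreadGroup::GetThreadId(). This is a host-side model with hypothetical
// names, not kernel code, and the group sizes below are chosen only for illustration.
template <int MathGroupSize, int LoadGroupSize>
struct WaveletPartitionModel
{
    static constexpr int BlockSize = MathGroupSize + LoadGroupSize;

    static constexpr bool is_math_thread(int tid) { return tid < MathGroupSize; }
    static constexpr bool is_load_thread(int tid)
    {
        return tid >= MathGroupSize && tid < BlockSize;
    }
    // Local id within the load group, as used to derive its copy offsets.
    static constexpr int load_thread_id(int tid) { return tid - MathGroupSize; }
};

// Example: 256 math threads + 256 load threads per workgroup.
using WaveletModelExample = WaveletPartitionModel<256, 256>;
static_assert(WaveletModelExample::is_math_thread(0) &&
                  WaveletModelExample::is_math_thread(255),
              "first 256 threads belong to the math group");
static_assert(WaveletModelExample::is_load_thread(256) &&
                  WaveletModelExample::load_thread_id(256) == 0,
              "load group starts at thread 256 with rebased id 0");
static_assert(WaveletModelExample::load_thread_id(511) == 255,
              "last load thread has rebased id 255");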
+ // TODO: build pipelines blocks scheduling parallel tasks + using GridwiseGemmLoad = GridwiseGemmLoadWave; + using GridwiseGemmMath = GridwiseGemmMathWave; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(AK0PerBlock, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(BK0PerBlock, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + + __host__ __device__ static constexpr auto + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(ABDataType), + c_block_size * sizeof(EDataTypeShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k, + const BGridDesc_N_K& b_grid_desc_n_k, + const EGridDesc_M_N& e_grid_desc_m_n, + const Block2ETileMap& /*block_2_etile_map*/) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + // check consistency of desc + if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && + K == b_grid_desc_n_k.GetLength(I1))) + { + return false; + } + + // check tile size + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + { + return false; + } + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmMath::IsSupported(num_k_loop)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + + // check tensor size: cannot be larger than 2GB each + constexpr long_index_t TwoGB = 
(long_index_t{1} << 31); + + if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && + b_grid_desc_n_k.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB && + e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB)) + { + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmMath::CalculateHasMainLoop(num_loop); + } + + // return block_id to E matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + constexpr auto M01 = I1; + constexpr auto N01 = I1; + + const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(M0, M01)), + make_unmerge_transform(make_tuple(N0, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); + + const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, N0, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto cblockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); + + return cblockid_to_m0_n0_block_cluster_adaptor; + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const EGridDesc_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); + + return grid_size; + } + + // A desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // B desc for source in blockwise copy + __host__ __device__ static constexpr auto + MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // E desc for destination in blockwise copy + template + __host__ __device__ static constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const EGridDescriptor_M_N& e_grid_desc_m_n) + { + const auto M = e_grid_desc_m_n.GetLength(I0); + const auto N = e_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = 
transform_tensor_descriptor( + e_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return e_grid_desc_mblock_mperblock_nblock_nperblock; + } + + using EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = + remove_cvref_t; + + using DefaultBlock2ETileMap = + remove_cvref_t; + + template + __device__ static void Run(const ABDataType* __restrict__ p_a_grid, + const ABDataType* __restrict__ p_b_grid, + EDataType* __restrict__ p_e_grid, + void* __restrict__ p_shared, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const EElementwiseOperation& e_element_op, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock& + e_grid_desc_mblock_mperblock_nblock_nperblock, + const Block2ETileMap& block_2_etile_map) + { + // build loadWave and MathWave pipelines + // loadWave and MathWave synchronized through LDS + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + if(TileLoadThreadGroup::IsBelong()) + { + + // LoadWave + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + 
a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + ABDataType, + ABDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + GridwiseGemmLoad::template RunLoadWavePipeline( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + num_k_block_main_loop); + + block_sync_lds(); + block_sync_lds(); + } + else if(TileMathThreadGroup::IsBelong()) + { + // branch early for math wave + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< + TileMathThreadGroupSize, + ABDataType, + ABDataType, + FloatGemmAcc, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack>{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + auto c_grid_buf = make_dynamic_buffer( + p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // TODO re-architect LDS+math stages + // Writing data to GMEM: only math wave is doing the work in cshuffle + GridwiseGemmMath::template RunMathWavePipeline( + a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_k_block_main_loop); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + FloatGemmAcc, + EDataTypeShuffle, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + 
m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + CShuffleBlockTransferThreadGroup, // ThreadGroup + EElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + EDataTypeShuffle, // typename SrcData, + EDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + e_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), + e_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + // Different way of getting coalesced writes: + // We can get rid of doing cshuffle. Instead of reading A rows in contiguous manner + // do it interleaved, then mfma can have nice c-mat layout as below: + // + // TODO + // We do not need to do LDS swizzle to align global writes writing cache lines: + // v_mfma cmat, amat, bmat, cmat - c-mat register layout are 1xN + // elments (N is vertical or strided + // dimension) + // v_mfma cmat, bmat, amat, cmat - c-mat register layout are Mx1 + // elments (M is coalescing + // dimension) by enumerating M index in + // amat, bmat you can align cmat + // register(s) to contiguous M elements + // for example + // 1st mfma instruction output space : 0 4 8 12 16 .... + // 2nd mfma instruction output space : 1 5 9 13 17 .... + // 3rd mfma instruction output space : 2 6 10 14 18 .... + // 4th mfma instruction output space : 3 7 11 15 19 .... 
+ // you can pack 4 registers output space into 2WORD and do global write + // (no LDS swizzling required) + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + e_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp new file mode 100755 index 000000000000..b7947309e432 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp @@ -0,0 +1,1012 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// Implementation of "Merge" transformation primitive that uses division and mod. 
It is supposed to +// be used for low_lengths that are known at compile time and are power of 2, otherwise performance +// will be very bad +template +struct Merge_v4_no_carry +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = + decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Merge_v4_no_carry() = default; + + __host__ __device__ constexpr Merge_v4_no_carry(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + // division and mod + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_low(i) = tmp / this->low_lengths_scan_[i]; + tmp %= this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff& idx_up_diff, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + constexpr auto INm1 = Number{}; + + index_t tmp = idx_up_new[I0]; + + idx_low(INm1) = tmp; + idx_diff_low(INm1) = idx_up_diff[I0]; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Merge_v3_direct_division_mod_wrw, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan_ "); + print_multi_index(low_lengths_scan_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + +template +__host__ __device__ constexpr auto make_merge_transform_v4_no_carry(const LowLengths& low_lengths) +{ + return Merge_v4_no_carry{low_lengths}; +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_bwd_weight(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const CBlockClusterAdaptor c_block_cluster_adaptor) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_grid_desc_mblock_mperblock_nblock_nperblock, + a_element_op, + b_element_op, + c_element_op, + c_block_cluster_adaptor); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c_block_cluster_adaptor; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + // denorm test fix, required to work around fp16 mfma issue + // we convert fp16->fp32->bf16 and execute bf16 mfma instruction + // when mfma if fixed, remove this section and update + // FloatAAdjusted -> ComputeTypeA, 
FloatBAdjusted -> ComputeTypeB, + // throughout this file +#if CK_GFX90A_DENORM_WORKAROUND + using FloatAAdjusted = + conditional_t, ck::bhalf_t, ComputeTypeA>; + using FloatBAdjusted = + conditional_t, ck::bhalf_t, ComputeTypeB>; +#else + using FloatAAdjusted = ComputeTypeA; + using FloatBAdjusted = ComputeTypeB; +#endif + + // M0/M1/M1Padding + static constexpr auto M1PerBlock = Number{}; + static constexpr auto M0PerBlock = Number{}; + static constexpr auto M1Padding = Number{}; + + // N0/N1/N1Padding + static constexpr auto N1PerBlock = Number{}; + static constexpr auto N0PerBlock = Number{}; + static constexpr auto N1Padding = Number{}; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + if constexpr(ABlockLdsExtraM1Wrw) + { + constexpr auto a_block_desc_k0_m0_m1_k1 = make_naive_tensor_descriptor( + make_tuple( + Number{}, Number{}, Number{}, K1), + make_tuple(Number{} * (Number{} * K1 + M1Padding), + Number{} * K1 + M1Padding, + K1, + I1)); + + constexpr auto a_block_desc_k0_m_k1_tmp = transform_tensor_descriptor( + a_block_desc_k0_m0_m1_k1, + make_tuple(make_pass_through_transform(Number{}), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_block_desc_k0_m_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_b_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + if constexpr(ABlockLdsExtraM1Wrw) + { + constexpr auto a_block_desc_b_k0_m0_m1_k1 = make_naive_tensor_descriptor( + make_tuple(Number<1>{}, + Number{}, + Number{}, + Number{}, + K1), + make_tuple(Number{} * Number{} * + (Number{} * K1 + M1Padding), + Number{} * (Number{} * K1 + M1Padding), + Number{} * K1 + M1Padding, + K1, + I1)); + + constexpr auto a_block_desc_b_k0_m_k1_tmp = transform_tensor_descriptor( + a_block_desc_b_k0_m0_m1_k1, + make_tuple(make_pass_through_transform(Number<1>{}), + make_pass_through_transform(Number{}), + make_merge_transform_v4_no_carry( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return a_block_desc_b_k0_m_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + + return a_block_desc_b_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of 
blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + if constexpr(BBlockLdsExtraN1Wrw) + { + constexpr auto b_block_desc_k0_n0_n1_k1 = make_naive_tensor_descriptor( + make_tuple( + Number{}, Number{}, Number{}, K1), + make_tuple(Number{} * (Number{} * K1 + N1Padding), + Number{} * K1 + N1Padding, + K1, + I1)); + + constexpr auto b_block_desc_k0_n_k1_tmp = transform_tensor_descriptor( + b_block_desc_k0_n0_n1_k1, + make_tuple(make_pass_through_transform(Number{}), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_block_desc_k0_n_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_b_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + if constexpr(BBlockLdsExtraN1Wrw) + { + constexpr auto b_block_desc_b_k0_n0_n1_k1 = make_naive_tensor_descriptor( + make_tuple(Number<1>{}, + Number{}, + Number{}, + Number{}, + K1), + make_tuple(Number{} * Number{} * + (Number{} * K1 + N1Padding), + Number{} * (Number{} * K1 + N1Padding), + Number{} * K1 + N1Padding, + K1, + I1)); + + constexpr auto b_block_desc_b_k0_n_k1_tmp = transform_tensor_descriptor( + b_block_desc_b_k0_n0_n1_k1, + make_tuple(make_pass_through_transform(Number<1>{}), + make_pass_through_transform(Number{}), + make_merge_transform_v4_no_carry( + make_tuple(Number{}, Number{})), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return b_block_desc_b_k0_n_k1_tmp; + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + + return b_block_desc_b_k0_n_k1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = math::integer_least_multiple( + a_b_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = math::integer_least_multiple( + b_b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto c_block_size = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); + + return math::max((a_block_space_size * sizeof(FloatAAdjusted) + + b_block_space_size * sizeof(FloatBAdjusted)), + c_block_size * 
sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); + const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2); + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) && + K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) && + K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) && + KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + // const bool has_main_k0_block_loop = K0 > K0PerBlock; + const index_t num_loop = K0 / K0PerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + + // return has_main_k0_block_loop; + } + + __host__ __device__ static constexpr auto + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CMNGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + return transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( + const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch) + { + return BlockToCTileMap_KSplit_M00_N00_M01_N01( + c_m_n_grid_desc, M01, N01, KBatch); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + } + + using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = + decltype(MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + + template + __device__ static void Run(const FloatA* __restrict__ p_a_grid, + const FloatB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc, + const BGridDesc_B_K0_N_K1& 
b_b_k0_n_k1_grid_desc, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const CBlockClusterAdaptor& c_block_cluster_adaptor) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t k_batch_id = block_work_idx[I0]; + + if(!c_block_cluster_adaptor.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto a_b_k0_m_k1_block_desc = GetABlockDescriptor_Batch_K0PerBlock_MPerBlock_K1(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto b_b_k0_n_k1_block_desc = GetBBlockDescriptor_Batch_K0PerBlock_NPerBlock_K1(); + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatA, + FloatAAdjusted, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatB, + FloatBAdjusted, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + 
// b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + K1 <= 4) || + (is_same::value && K1 <= 8) || + ((is_same::value || is_same::value) && + K1 < 32)) + ? true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t KPack = math::max(K1, + MfmaSelector::selected_mfma.k_per_blk); + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size, + b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_b_k0_m_k1_grid_desc, + a_b_k0_m_k1_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_b_k0_n_k1_grid_desc, + b_b_k0_n_k1_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + static_assert(M1 == MWave, ""); + static_assert(N1 == NWave, ""); + static_assert(M2 * M3 * M4 == MPerXDL, ""); + static_assert(N2 == NPerXDL, ""); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // 
M1 = MWave, M2 * M3 * M4 = MPerXDL + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, 
CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); + } + }); + } + } +}; // namespace ck + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp new file mode 100755 index 000000000000..7c401a49572a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp @@ -0,0 +1,679 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
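+// Gridwise GEMM that skips LDS staging for the B matrix: only the A tile is
+// written to LDS (GetSharedMemoryNumberOfByte() below accounts for A alone),
+// while B is streamed from global memory straight into per-thread register
+// buffers (BBlockBufferSize of them) with ThreadwiseTensorSliceTransfer_v2.
+// CheckValidity() consequently requires (K0 / K0PerBlock) % BBlockBufferSize == 0.
+//
+// Rough shape of the main loop in Run() (sketch of the code below, nothing extra):
+//   prefetch: A tile -> LDS, BBlockBufferSize B slices -> register buffers
+//   loop:     for i in [0, BBlockBufferSize): mfma(A from LDS, B buffer i),
+//             refill B buffer i from global; after the sweep, write the next A tile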
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_skip_b_lds_v1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3; + ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + static constexpr index_t WaveSize = 64; + static constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXDL); + + static constexpr auto xdlops_gemm = XdlopsGemm{}; + static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops; + + using ThisThreadBlock = ThisThreadBlock; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), + max_lds_align); + } + }(); + + return 
a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size_aligned) * sizeof(FloatAB); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + index_t M01, + index_t N01) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // 2-stage prefetch currently only support even number of K0 loop + // TODO: add support for odd number of K0 loop + if(!((K0 / K0PerBlock) % BBlockBufferSize == 0)) + { + return false; + } + + // check M01, N01 + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + if(!(M0 % M01 == 0 && N0 % N01 == 0)) + return false; + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock); + + return grid_size; + } + + // TODO move this function into GEMM-pipeline class + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = (K0 / (BBlockBufferSize * K0PerBlock)) > 1; + + return has_main_k0_block_loop; + } + + __host__ __device__ static constexpr auto + MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1) + { + const auto K0 = b_grid_desc_k0_n_k1.GetLength(I0); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + + const auto b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1 = transform_tensor_descriptor( + b_grid_desc_k0_n_k1, + make_tuple(make_unmerge_transform( + make_tuple(K0 / K0PerBlock, xdlops_gemm.K0PerXdlops, K0PerThread)), + make_unmerge_transform(make_tuple( + N / (NXdlPerWave * NWaves * NPerXDL), NXdlPerWave, NWaves, NPerXDL)), + make_pass_through_transform(K1)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5, 6>{}, Sequence<7>{})); + return b_griddesc_k0_nblockid_nrepeat_waves_nperxdlops_k1; + } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = get_thread_local_1d_id(); + + constexpr auto threadid_to_wave_idx_adaptor = 
make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto GetWaveKNIdx(const index_t thread_id) + { + constexpr auto wave_threadid_to_nk_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(xdlops_gemm.K0PerXdlops, NPerXDL))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + return wave_threadid_to_nk_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix threadwise copy + constexpr auto b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + I1, + Number{}, // K0PerThread + I1, // NBlockId + Number{}, // repeat + I1, // waves + I1, // NPerXdlops + Number{})); + + using BlockwiseGemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1< + BlockSize, + FloatAB, + FloatAcc, + decltype(a_block_desc_k0_m_k1), + decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3), + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + K1>; + + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto + MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n, index_t M01, index_t N01) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto M00 = M0 / M01; + const auto N00 = N0 / N01; + + const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); + + const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto cblockid_to_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + cblockid_to_m00_m01_n00_n01_block_cluster_adaptor); + + return cblockid_to_m0_n0_block_cluster_adaptor; + } + + using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 = + decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); + using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1)); + using BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 = + decltype(MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(BGridDesc_K0_N_K1{})); + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const 
FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); + + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + 1>(a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + ignore = b_element_op; + // B matrix threadwise copy + constexpr auto b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + I1, + Number{}, // K0PerThread + I1, // NBlockId + Number{}, // repeat + I1, // waves + I1, // NPerXdlops + Number{})); + + auto b_thread_buf = generate_tuple( + [&](auto i) { + ignore = i; + return StaticBuffer{}; + }, + Number{}); + + const auto wave_id = GetWaveIdx(); + const auto wave_k_n_id = GetWaveKNIdx(wave_id[I2]); + +#if 0 + const index_t block_id = get_block_1d_id(); + const index_t thread_id = get_thread_local_1d_id(); + printf("block id: %d m blockid: %d n block id: %d ,thread id: %d, wave id :{%d %d %d} " + "kn id: {%d %d}\n", + block_id, + block_work_idx[I0], + block_work_idx[I1], + thread_id, + wave_id[I0], + wave_id[I1], + wave_id[I2], + wave_k_n_id[I0], + wave_k_n_id[I1]); + printf("mfma thread k per xdlops: %d K0PerThread: %d HasMainK0BlockLoop: %d K0: %d \t", + xdlops_gemm.K0PerXdlops, K0PerThread, HasMainK0BlockLoop, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I0)); +#endif + + auto b_threadwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + I1, + Number{}, + I1, + I1, + Number{}>, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + BBlockTransferSrcScalarPerVector, + 
BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + make_multi_index( + 0, wave_k_n_id[I0], 0, block_work_idx[I1], 0, wave_id[I1], wave_k_n_id[I1], 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1< + BlockSize, + FloatAB, + FloatAcc, + decltype(a_block_desc_k0_m_k1), + decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3), + MPerBlock, + NPerBlock, + K0PerBlock, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + K1>{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + // gridwise GEMM pipeline + constexpr auto a_block_slice_copy_step = + make_multi_index(K0PerBlock * BBlockBufferSize, 0, 0); + constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0); + // preload data to regiester and LDS + { + // Read + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, a_block_slice_copy_step); + + static_for<0, BBlockBufferSize, 1>{}([&](auto ii) { + b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + b_grid_buf, + b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_buf(Number{})); + b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + b_thread_slice_copy_step); + }); + + // Initialize C + c_thread_buf.Clear(); + // a data write to lds + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + // main body + if constexpr(HasMainK0BlockLoop) + { + index_t K0BlockMainLoop = + __builtin_amdgcn_readfirstlane(K0 / (BBlockBufferSize * K0PerBlock)); + index_t i = 0; + do + { + a_blockwise_copy.RunRead(a_grid_desc_k0_m_k1, a_grid_buf); + blockwise_gemm.ResetABlockStartWindow(); + block_sync_lds(); + + static_for<0, BBlockBufferSize, 1>{}([&](auto ii) { + blockwise_gemm.Run(a_block_buf, b_thread_buf(Number{}), c_thread_buf); + blockwise_gemm.MoveABlockSliceWindow(); + s_nop(); + + b_threadwise_copy.Run(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + b_grid_buf, + b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + b_thread_buf(Number{})); + b_threadwise_copy.MoveSrcSliceWindow(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3, + b_thread_slice_copy_step); + }); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf); + // move a and b window + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m_k1, + a_block_slice_copy_step); + + i += 1; + } while(i < (K0BlockMainLoop - 1)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.ResetABlockStartWindow(); + + static_for<0, BBlockBufferSize, 1>{}([&](auto ii) { + blockwise_gemm.Run(a_block_buf, b_thread_buf(Number{}), c_thread_buf); + blockwise_gemm.MoveABlockSliceWindow(); + }); + } + } + + // output: register to global memory + { + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); + constexpr auto N0 = 
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3], + m_thread_data_on_grid_idx[I4], + n_thread_data_on_grid_idx[I2]), + c_element_op}; + + c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp new file mode 100755 index 000000000000..3e23008a5ff0 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_splitk_lds_direct_load.hpp @@ -0,0 +1,972 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
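+// Split-K gridwise GEMM in which both the A and B tiles are brought into LDS
+// with direct (global -> LDS) loads via ThreadGroupTensorSliceTransfer_DirectLoad.
+// The A/B elementwise operations are accepted only for API consistency and are
+// ignored inside Run(). Work is distributed over a 3D grid of
+// (N blocks, M blocks, k_batch); see CalculateGridSize() and
+// BlockToCTileMap_3DGrid_KSplit.
+//
+// K padding follows K0Padded = ceil(K / (k_batch * K0PerBlock * K1)) * K0PerBlock
+// and KPadded = k_batch * K0Padded * K1. For example (illustrative numbers only),
+// K = 1000, k_batch = 4, K0PerBlock = 4, K1 = 8 gives K0Padded = 32 and
+// KPadded = 1024.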
+ +#pragma once + +#include "ck/utility/amd_lds.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_splitk_lds_direct_load(typename GridwiseGemm::Argument karg, + const Block2CTileMap& b2c_map, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + + __shared__ uint8_t p_shared[shared_size]; + + GridwiseGemm::template Run( + karg, static_cast(p_shared), b2c_map, a_element_op, b_element_op, c_element_op); +#else + ignore = karg; + ignore = b2c_map; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemm_xdlops_splitk_lds_direct_load +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + static constexpr auto KPerBlock = Number{}; + static constexpr auto M01 = 1; + static constexpr auto N01 = 1; + + static constexpr auto gemm_padder = + tensor_operation::device::GemmPadder{ + MPerBlock, NPerBlock, K1* K0PerBlock}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + struct Argument : public ck::tensor_operation::device::BaseArgument + { + const FloatA* p_a_grid; + const FloatB* p_b_grid; + FloatC* p_c_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t KPadded; + index_t K0Padded; + index_t k_batch; + + Argument(const FloatA* p_a_grid_, + const FloatB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0Padded_, + index_t k_batch_) + : p_a_grid(p_a_grid_), + p_b_grid(p_b_grid_), + p_c_grid(p_c_grid_), + M(M_), + N(N_), + K(K_), + StrideA(StrideA_), + StrideB(StrideB_), + StrideC(StrideC_), + MPadded(MPadded_), + 
NPadded(NPadded_), + KPadded(KPadded_), + K0Padded(K0Padded_), + k_batch(k_batch_) + { + } + + void Print() const + { + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " + << "K0Padded:" << K0Padded << ", " + << "KB:" << k_batch << "}" << std::endl; + } + }; + + __host__ __device__ static auto CalculateGridSize(const Argument& karg) + { + return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock), + math::integer_divide_ceil(karg.M, MPerBlock), + karg.k_batch); + } + + // prefer this to be called on host + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1) + { + // k_batch * k0 * k0_per_block * k1 + auto K_t = K_Batch * K0PerBlock * K1; + return (K + K_t - 1) / K_t * K0PerBlock; + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K0Padded = CalculateK0Padded(K, K_Batch); + return K_Batch * K0Padded * K1; + } + + __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, + index_t MPad, + index_t K, + index_t StrideA, + index_t KBatch, + index_t K0Padded, + index_t KPad) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + __host__ 
__device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, + index_t NPad, + index_t N, + index_t StrideB, + index_t KBatch, + index_t K0Padded, + index_t KPad) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // LDS allocation for A and B: be careful of alignment + 
constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto c_block_size = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); + + return math::max(NumGemmKPrefetchStage * (a_block_space_size + b_block_space_size) * + sizeof(ComputeType), + c_block_size * sizeof(FloatC)); + } + + __host__ __device__ static constexpr bool CheckValidity(const Argument& karg) + { + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.k_batch * K0PerBlock * K1; + if(!(karg.K % K_t == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + { + return false; + } + } + else + { + if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + { + return false; + } + } + + const auto num_k_loop = karg.K0Padded / K0PerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + return true; + } + + __host__ __device__ static auto GetKPad(index_t K, index_t KBatch) + { + const index_t K0Padded = + math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; + const index_t KPad = KBatch * K0Padded * K1; + return KPad; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded) + { + const index_t num_loop = K0Padded / K0PerBlock; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + return transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, 
Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + } + + // return block_id to C matrix tile idx (m0, n0, k_split) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap() + { + return BlockToCTileMap_3DGrid_KSplit(); + } + + using CGridDesc_M_N = remove_cvref_t; + using DefaultBlock2CTileMap = remove_cvref_t; + + template + __device__ static void Run(const Argument& karg, + void* __restrict__ p_shared_block, + const Block2CTileMap& block_2_ctile_map, + const AElementwiseOperation a_element_op = AElementwiseOperation{}, + const BElementwiseOperation b_element_op = BElementwiseOperation{}, + const CElementwiseOperation c_element_op = CElementwiseOperation{}) + { + // Elementwise operations are not supported for A and B, arguments left only for the API + // consistency. + (void)a_element_op; + (void)b_element_op; + + const FloatA* p_a_grid = karg.p_a_grid; + const FloatB* p_b_grid = karg.p_b_grid; + FloatC* p_c_grid = karg.p_c_grid; + const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded); + const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1( + karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide block work by [KBatch, M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]); + const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(K1, Number{}, I1)); + } + }(); + + constexpr auto 
a_b_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple( + Number{} * Number{}, K1, Number{}, I1)); + } + }(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(K1, Number{}, I1)); + } + }(); + + constexpr auto b_b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple( + Number{} * Number{}, K1, Number{}, I1)); + } + }(); + + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferSrcAccessOrder, + FloatA, + ComputeType, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferSrcAccessOrder, + FloatB, + ComputeType, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + ComputeType, // ComputeType A + ComputeType, // ComputeType B + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + K1, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + const auto a_buffers_offset = 0; + auto a_block_buffers = + ck::lds_utils::AllocateLdsBuffers( + p_shared_block, + a_b_k0_m_k1_block_desc.GetElementSpaceSize(), + a_buffers_offset, + max_lds_align); + const auto b_buffers_offset = a_block_space_size * NumGemmKPrefetchStage; + auto b_block_buffers = + ck::lds_utils::AllocateLdsBuffers( + p_shared_block, + b_b_k0_n_k1_block_desc.GetElementSpaceSize(), + b_buffers_offset, + max_lds_align); + + // gridwise GEMM pipeline + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) / + (K0PerBlock * K1)); + + const auto 
gridwise_gemm_pipeline = GridwiseGemmPipe{}; + + gridwise_gemm_pipeline.template Run(a_b_k0_m_k1_grid_desc, + a_b_k0_m_k1_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buffers, + a_block_slice_copy_step, + b_b_k0_n_k1_grid_desc, + b_b_k0_n_k1_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buffers, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared_block), + c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + 
ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? 
nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp new file mode 100755 index 000000000000..e9190dee2920 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp @@ -0,0 +1,1185 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
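+// Stream-K gridwise GEMM (XDL/MFMA path).
+//
+// Rather than pinning each workgroup to a fixed (M, N) output tile, Stream-K hands every
+// workgroup a contiguous slice of a flattened (tile, K-chunk) iteration space, so all CUs
+// stay busy even when the tile count does not divide the CU count evenly. Blocks that only
+// produce a partial K-sum for a tile either accumulate into C atomically
+// (StreamKReductionStrategy::Atomic) or store partial accumulators to a workspace that
+// dedicated reduction blocks combine afterwards (StreamKReductionStrategy::Reduction);
+// Block2CTileMap below carries the block/tile/iteration bookkeeping and the workspace layout.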
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v3.hpp" +#include "ck/utility/workgroup_barrier.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid, + const typename GridwiseGemm::FloatAB* p_b_grid, + typename GridwiseGemm::FloatC* p_c_grid, + void* p_workspace, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + typename GridwiseGemm::Block2CTileMap block_mapping) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + + __shared__ uint8_t p_shared[shared_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_workspace, + M, + N, + K, + StrideA, + StrideB, + StrideC, + block_mapping, + static_cast(p_shared)); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_workspace; + ignore = M; + ignore = N; + ignore = K; + ignore = StrideA; + ignore = StrideB; + ignore = StrideC; + ignore = block_mapping; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + static constexpr auto M01 = 1; + static constexpr auto N01 = 1; + static constexpr auto KPerBlock = K0PerBlock * K1; + + using ThisThreadBlock = ThisThreadBlock; + using FloatAcc = FloatAcc_; + using FloatCShuffle = FloatAcc; + + using Block2CTileMap = Block2CTileMap_; + using FloatAB = FloatAB_; + using FloatC = FloatC_; + + struct Argument : public ck::tensor_operation::device::BaseArgument + { + const FloatAB* p_a_grid; + const FloatAB* p_b_grid; + FloatC* p_c_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + Block2CTileMap block_mapping; + + Argument(const FloatAB* p_a_grid_, + const FloatAB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + uint32_t num_cu, + uint32_t occupancy, + uint32_t num_sk_blocks_) + : p_a_grid(p_a_grid_), + p_b_grid(p_b_grid_), + p_c_grid(p_c_grid_), + M(M_), + N(N_), 
+ K(K_), + StrideA(StrideA_), + StrideB(StrideB_), + StrideC(StrideC_), + block_mapping(M, N, K, num_cu, occupancy, num_sk_blocks_) + { + } + + void Print() const + { + std::cout << "arg {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << std::endl; + } + }; + + __host__ __device__ static auto CalculateGridSize(const Argument& karg) + { + return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock), + math::integer_divide_ceil(karg.M, MPerBlock), + karg.k_batch); + } + + __host__ __device__ static auto CalculateK0(index_t KPad) { return KPad / K1; } + + __host__ __device__ static auto + MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA) + { + const index_t K0 = CalculateK0(KPad); + + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor(a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + __host__ __device__ static auto + MakeBGridDescriptor_K0_N_K1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB) + { + const index_t K0 = CalculateK0(KPad); + + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor(b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + return transform_tensor_descriptor(c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * 
K1, K1, I1)); + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto c_block_size = + GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle().GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + __host__ __device__ static constexpr bool CheckValidity(const Argument& karg) + { + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + return false; + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + return false; + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + return false; + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + return false; + } + + if constexpr(is_same::value) + { + if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + return false; + } + else + { + if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + return false; + } + + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = K0 > K0PerBlock; + + return has_main_k0_block_loop; + } + + template + __host__ __device__ static constexpr auto + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + return transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( + const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) + { + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + c_m_n_grid_desc, 8, KBatch); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = 
NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})); + } + + __host__ __device__ static constexpr auto GetClusterLengthReduction() + { + // TODO: assume C is row major + // TODO: we always first loop over N, then M + constexpr auto NPerBlockPow2 = math::next_power_of_two(); + constexpr auto NPerBlockReduction = + NPerBlockPow2 / CBlockTransferScalarPerVector_NWaveNPerXDL; + constexpr auto MPerBlockReduction = + (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction; + return Sequence{}; + } + + __host__ __device__ static constexpr auto GetPartialAccBlockDescriptor() + { + const auto c_partial_acc_block_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock), + make_tuple(NPerBlock, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MPerBlock, NPerBlock), + make_tuple(I1, MPerBlock)); + } + }(); + return c_partial_acc_block_m_n; + } + + using CGridDesc_M_N = remove_cvref_t; + + __device__ static void Run(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + void* p_workspace, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + Block2CTileMap block_mapping, + void* __restrict__ p_shared_block) + { + uint32_t m = M; + uint32_t n = N; + uint32_t k = K; + uint32_t pad_m = (m + MPerBlock - 1) / MPerBlock * MPerBlock; + uint32_t pad_n = (n + NPerBlock - 1) / NPerBlock * NPerBlock; + uint32_t pad_k = (k + KPerBlock - 1) / KPerBlock * KPerBlock; + uint32_t stride_a = StrideA; + uint32_t stride_b = StrideB; + uint32_t stride_c = StrideC; + + const auto a_k0_m_k1_grid_desc = MakeAGridDescriptor_K0_M_K1(m, pad_m, k, pad_k, stride_a); + const auto b_k0_n_k1_grid_desc = MakeBGridDescriptor_K0_N_K1(k, pad_k, n, pad_n, stride_b); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(m, pad_m, n, pad_n, stride_c); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + const AElementwiseOperation a_element_op = AElementwiseOperation{}; + const BElementwiseOperation b_element_op = BElementwiseOperation{}; + const CElementwiseOperation c_element_op = CElementwiseOperation{}; + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = static_cast(p_shared_block); + FloatAB* p_b_block = static_cast(p_shared_block) + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); 
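+        // The B tile advances by the same K0PerBlock step as A: on every main-loop iteration
+        // both block windows slide one K0-chunk along their (K0, M/N, K1) grid descriptors.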
+ constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize()); + + // gridwise GEMM pipeline + const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_v3(); + + uint32_t block_idx = block_mapping.get_block_idx(); + bool is_sk_block = block_idx < block_mapping.sk_num_blocks; + bool is_dp_block = block_idx >= block_mapping.dp_start_block_idx && + block_idx < block_mapping.reduction_start_block_idx; + bool is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx; + bool is_padding_block = block_idx >= block_mapping.sk_num_blocks && + block_idx < block_mapping.dp_start_block_idx; + uint32_t iter_start, iter_end; + block_mapping.get_block_itr(block_idx, iter_start, iter_end); + uint32_t total_iter_length = iter_end - iter_start; + + if(is_padding_block) + return; + + uint32_t* p_semaphore = + reinterpret_cast(reinterpret_cast(p_workspace) + + block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc))); + + if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction) + { + if(is_reduction_block) + { + // descriptors + constexpr auto cluster_length_reduce = GetClusterLengthReduction(); + constexpr auto reduce_desc = make_cluster_descriptor(cluster_length_reduce); + const auto reduce_thread_cluster_idx = + reduce_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); + const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0]; + const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1]; + + constexpr auto MReduceIters = + math::integer_divide_ceil(Number{}, cluster_length_reduce.At(I0)); + constexpr auto NReduceIters = math::integer_divide_ceil( + Number{}, + cluster_length_reduce.At(I1) * + Number{}); + + constexpr auto acc_thread_buf_load_desc = make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{})); + constexpr auto acc_thread_buf_store_desc = make_naive_tensor_descriptor_packed( + make_tuple(I1, I1, I1, Number{})); + + constexpr auto c_partial_acc_block_m_n = GetPartialAccBlockDescriptor(); + + constexpr auto partial_acc_load_step_n = make_multi_index( + 0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL); + constexpr auto partial_acc_load_step_n_reverse = + make_multi_index(0, + -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * + CBlockTransferScalarPerVector_NWaveNPerXDL); + constexpr auto partial_acc_load_step_m = + make_multi_index(cluster_length_reduce.At(I0), 0); + + constexpr auto partial_acc_store_step_n = make_multi_index( + 0, + 0, + 0, + cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL); + constexpr auto partial_acc_store_step_n_reverse = + make_multi_index(0, + 0, + 0, + -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) * + CBlockTransferScalarPerVector_NWaveNPerXDL); + constexpr auto partial_acc_store_step_m = + make_multi_index(0, cluster_length_reduce.At(I0), 0, 0); + + StaticBuffer + parcial_acc_buf; + StaticBuffer + acc_buf; + + // start to compute + auto reduction_idx = blockIdx.x - block_mapping.reduction_start_block_idx; + auto spatial_idx = block_mapping.tile_to_spatial(reduction_idx, m, n); + + workgroup_barrier wg_barrier(p_semaphore); + + uint32_t tile_acc_offset_start = + block_mapping.get_acc_buffer_offset_from_tile(reduction_idx); + uint32_t tile_acc_offset_end = + 
block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1); + + auto acc_load = ThreadwiseTensorSliceTransfer_v2< + FloatAcc, // SrcData, + FloatAcc, // DstData, + decltype(c_partial_acc_block_m_n), // SrcDesc, + decltype(acc_thread_buf_load_desc), // DstDesc, + Sequence<1, CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths, + Sequence<0, 1>, // DimAccessOrder, + 1, // SrcVectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // SrcScalarPerVector, + 1, // SrcScalarStrideInVector, + false // SrcResetCoordinateAfterRun, + >{c_partial_acc_block_m_n, + make_multi_index(thread_m_cluster_id, + thread_n_cluster_id * + CBlockTransferScalarPerVector_NWaveNPerXDL)}; + + auto acc_store = ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, // SrcData, + FloatC, // DstData, + decltype(acc_thread_buf_store_desc), // SrcDesc, + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), // DstDesc, + CElementwiseOperation, // ElementwiseOperation, + Sequence<1, 1, 1, CBlockTransferScalarPerVector_NWaveNPerXDL>, // SliceLengths, + Sequence<0, 1, 2, 3>, // DimAccessOrder, + 3, // DstVectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // DstScalarPerVector, + InMemoryDataOperationEnum::Set, // InMemoryDataOperationEnum DstInMemOp, + 1, // DstScalarStrideInVector, + false // DstResetCoordinateAfterRun, + >{c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]), + thread_m_cluster_id, + __builtin_amdgcn_readfirstlane(spatial_idx[I1]), + thread_n_cluster_id * + CBlockTransferScalarPerVector_NWaveNPerXDL), + CElementwiseOperation{}}; + + // block synchronization + wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start); + +#if 0 + if(threadIdx.x == 0) { + printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast(blockIdx.x), + reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end), + __builtin_amdgcn_readfirstlane(spatial_idx[I0]), + __builtin_amdgcn_readfirstlane(spatial_idx[I1])); + } +#endif + + using Accumulation = ck::detail:: + AccumulateWithNanCheck; + + for(int i_m = 0; i_m < MReduceIters; i_m++) + { + static_for<0, NReduceIters, 1>{}([&](auto i_n_reduce) { + acc_buf.Clear(); + for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++) + { + auto c_partial_acc_buf = + make_dynamic_buffer( + reinterpret_cast(p_workspace) + + i * c_partial_acc_block_m_n.GetElementSpaceSize(), + c_partial_acc_block_m_n.GetElementSpaceSize()); + + acc_load.Run(c_partial_acc_block_m_n, + c_partial_acc_buf, + acc_thread_buf_load_desc, + make_tuple(I0, I0), + parcial_acc_buf); + + static_for<0, CBlockTransferScalarPerVector_NWaveNPerXDL, 1>{}( + [&](auto i_vec) { + constexpr auto offset = + acc_thread_buf_load_desc.CalculateOffset( + make_tuple(0, i_vec)); + Accumulation::Calculate(acc_buf(Number{}), + parcial_acc_buf[Number{}]); + }); + } + + if(thread_n_cluster_id * CBlockTransferScalarPerVector_NWaveNPerXDL < + NPerBlock) + { + acc_store.Run(acc_thread_buf_store_desc, + make_tuple(I0, I0, I0, I0), + acc_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + if constexpr(NReduceIters != 1) + { + if constexpr(i_n_reduce != (NReduceIters - 1)) + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_n); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_n); + } + else + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + 
partial_acc_load_step_n_reverse); + acc_store.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_n_reverse); + } + } + }); + { + acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n, + partial_acc_load_step_m); + acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock, + partial_acc_store_step_m); + } + } + return; + } + } + + // offset for last acc buffer of this block + uint32_t block_acc_offset = + (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock * + NPerBlock; + + while(true) + { + uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( + block_mapping.get_current_iter_length(iter_start, iter_end, total_iter_length)); + uint32_t tile_idx, iter_offset; + block_mapping.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset); + iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1); + auto spatial_idx = block_mapping.tile_to_spatial(tile_idx, m, n); + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(spatial_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(spatial_idx[I1] * NPerBlock); + + const index_t k0_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_k0_m_k1_grid_desc), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_k0_m_k1_grid_desc, + make_multi_index(k0_block_data_idx_on_grid, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_k0_n_k1_grid_desc), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_k0_n_k1_grid_desc, + make_multi_index(k0_block_data_idx_on_grid, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + const index_t num_k_block_main_loop = current_iter_length; + + gridwise_gemm_pipeline.Run(a_k0_m_k1_grid_desc, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_k0_n_k1_grid_desc, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + 
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mpershuffle_nblock_npershuffle = + GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle(); + + constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle = + GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle(); + + auto c_block_buf = make_dynamic_buffer( + reinterpret_cast(p_shared_block), + c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize()); + + auto c_partial_acc_buf = + make_dynamic_buffer( + reinterpret_cast(p_workspace) + block_acc_offset, + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mpershuffle_nblock_npershuffle, + make_tuple(make_freeze_transform(I0), // freeze mblock + make_unmerge_transform( + make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform( + make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + FloatAcc, + FloatCShuffle, + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + 
m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + // InMemoryDataOperationEnum::Set, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mpershuffle_nblock_npershuffle, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(__builtin_amdgcn_readfirstlane(spatial_idx[I0]), + 0, + __builtin_amdgcn_readfirstlane(spatial_idx[I1]), + 0), + c_element_op}; + + // LDS to global partial acc + auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + // InMemoryDataOperationEnum::Set, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatCShuffle, // typename DstData, + decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle), + decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + false, // bool ThreadTransferSrcResetCoordinateAfterRun, => need to be false, + // othre wise has scratch + false> // bool ThreadTransferDstResetCoordinateAfterRun, => need to be false, + // othre wise has scratch + {c_block_desc_mblock_mpershuffle_nblock_npershuffle, + make_multi_index(0, 0, 0, 0), + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + make_multi_index(0, 0, 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? 
nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + c_block_copy_lds_to_global.SetSrcSliceOrigin( + c_block_desc_mblock_mpershuffle_nblock_npershuffle, + make_tuple(0, 0, 0, 0)); + + // LDS to global + if(is_dp_block) + c_block_copy_lds_to_global.template Run( + c_block_desc_mblock_mpershuffle_nblock_npershuffle, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + else if(is_sk_block) + { + if constexpr(Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + // constexpr offset + c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin( + c_block_desc_mblock_mpershuffle_nblock_npershuffle, + make_tuple(0, 0, 0, 0)); + + c_block_copy_lds_to_partial_acc.SetDstSliceOrigin( + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + make_tuple(mxdlperwave.value, 0, nxdlperwave.value, 0)); + + c_block_copy_lds_to_partial_acc + .template Run( + c_block_desc_mblock_mpershuffle_nblock_npershuffle, + c_block_buf, + c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle, + c_partial_acc_buf); + } + else if constexpr(Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Atomic) + { + c_block_copy_lds_to_global + .template Run( + c_block_desc_mblock_mpershuffle_nblock_npershuffle, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + } + } + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + mxdlperwave_forward_step); + } + }); + + if constexpr(Block2CTileMap::ReductionStrategy == + StreamKReductionStrategy::Reduction) + { + if(is_sk_block) + { + // increase the counter for this tile + workgroup_barrier wg_barrier(p_semaphore); + wg_barrier.inc(tile_idx); + } + } + } + + // exit condition + iter_end -= current_iter_length; + if(iter_end <= iter_start) + break; + + if constexpr(Block2CTileMap::ReductionStrategy == StreamKReductionStrategy::Reduction) + { + block_acc_offset -= MPerBlock * NPerBlock; + } + // make sure next loop LDS is ready for use + block_sync_lds(); + } + } + + template + struct LStr + { + static std::string Get() { return ""; } + }; + + template <> + struct LStr + { + static std::string Get() { return "R"; } + }; + + template <> + struct LStr + { + static std::string Get() { return "C"; } + }; + + static std::string GetTypeString() + { + auto str = std::stringstream(); + + // clang-format off + str << "GemmXdlStreamK_" + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << "_" + << "B" << BlockSize << 
"_" + << "Vec" << ABlockTransferSrcScalarPerVector << "x" + << BBlockTransferSrcScalarPerVector << "x" + << CBlockTransferScalarPerVector_NWaveNPerXDL << "_" + << MPerBlock << "x" + << NPerBlock << "x" + << K0PerBlock << "x" + << K1 ; + // clang-format on + + return str.str(); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp new file mode 100755 index 000000000000..5c3d9b7ba4bc --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -0,0 +1,1053 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif +#if CK_USE_WAVES_PER_EU + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) +#endif + kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDesc_M_N c_grid_desc_m_n) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m_n); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_m_n; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif +#if CK_USE_WAVES_PER_EU + __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU))) +#endif + kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const auto a_grid_desc_k0_m_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA)); + const auto b_grid_desc_k0_n_k1 = + amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1( + karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB)); + const auto c_grid_desc_m_n = 
amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N( + karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC)); + + GridwiseGemm::template Run(karg.p_a_grid, + karg.p_b_grid, + karg.p_c_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_m_n); +#else + ignore = karg; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + template + __host__ static auto CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(c_grid_desc_m_n), 1, 1); + } + + template + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock) * MPerBlock; + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; + } + + __host__ static auto CalculateK0(index_t K) { return math::integer_divide_ceil(K, K1Value); } + + // Argument + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + K0{CalculateK0(K_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "K0:" << K0 << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t K0; + }; + + // Argument + struct Argument : public Problem, public tensor_operation::device::BaseArgument + { + __host__ Argument(const FloatAB* p_a_grid_, + const FloatAB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_} + { + } + + const FloatAB* p_a_grid; + const FloatAB* p_b_grid; + FloatC* p_c_grid; + }; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + // denorm test fix, required to work around fp16 mfma issue + // we convert fp16->fp32->bf16 and execute bf16 mfma instruction + // when mfma if fixed, remove this section and update + // FloatABAdjusted -> FloatAB throughout this file +#if CK_GFX90A_DENORM_WORKAROUND + using FloatABAdjusted = conditional_t, ck::bhalf_t, FloatAB>; +#else + using FloatABAdjusted = FloatAB; +#endif + + __host__ 
__device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB); + } + + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert(is_known_at_compile_time>::value, + "wrong! 
K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + "Invalid tuning param!"); + + // check gridwise gemm pipeline + const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock); + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = math::integer_divide_ceil(K, K0PerBlock * K1); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc& c_grid_desc_m_n) + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + using BlockwiseGemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt; + + template + __device__ static void Run(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + const auto block_2_ctile_map = + Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)}; + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0), + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr 
auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatABAdjusted, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatABAdjusted, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + FloatABAdjusted, + FloatABAdjusted, + FloatAcc, + decltype(a_block_desc_k0_m_k1), + decltype(b_block_desc_k0_n_k1), + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + K1, + LoopSched>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_k0_n_k1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + 
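        // Epilogue layout note (descriptive only): each thread's accumulator fragment is
        // addressed through the 8-D view (M0, N0, M1, N1, M2, M3, M4, N2) produced by the
        // blockwise xdlops GEMM. The two single-stage adaptors below merge
        // (M0, M1, M2, M3, M4) -> m and (N0, N1, N2) -> n, so a thread's (m, n) origin on the
        // C grid can be converted back into that 8-D index before
        // ThreadwiseTensorSliceTransfer_v1r3 writes the fragment out to global memory.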
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3], + m_thread_data_on_grid_idx[I4], + n_thread_data_on_grid_idx[I2]), + c_element_op}; + + c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_grid_buf); + } + } +}; + +template +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext + : GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 +{ + using Parent = + GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3; + + using typename Parent::GridwiseGemmPipe; + using typename Parent::Problem; + + using Parent::I1; + + using Parent::K1; + + __device__ static auto + MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t K0, index_t StrideA) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) 
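        // MNKPadding branch (descriptive only): right-pad K up to K0Pad * K1, where K0Pad is a
        // multiple of K0PerBlock, and right-pad M up to MPad, then unmerge the padded K into
        // (K0Pad, K1) so the resulting K0 x M x K1 descriptor tiles evenly by the block sizes.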
+ { + const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock; + const auto KPad = K0Pad * K1Value; + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + __device__ static auto + MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t NPad, index_t K0, index_t StrideB) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock; + const auto KPad = K0Pad * K1Value; + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(K0Pad, K1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + } + + __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == 
tensor_operation::device::GemmSpecialization::MNKPadding) + { + return transform_tensor_descriptor(c_grid_desc_m_n, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + return transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + __host__ static constexpr bool CheckValidity(const Problem& problem) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXDL)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.M % MPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.N % NPerBlock == 0)) + { + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(problem.K0 % K0PerBlock == 0)) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.K % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.M % ABlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(problem.N % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + else + { + if(problem.K % BBlockTransferSrcScalarPerVector != 0) + { + return false; + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock); + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp new file mode 100755 index 000000000000..7d8e94c00166 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp @@ -0,0 +1,617 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
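The v2r3 gridwise GEMM above views A as a K0 x M x K1 tensor and B as K0 x N x K1, with K1 fixed at compile time, and its Problem struct rounds M and N up to the block tile and derives K0 = ceil(K / K1) on the host. A minimal standalone sketch of that arithmetic follows; the MPerBlock / NPerBlock / K1 values and the sample problem are illustrative assumptions, not values mandated by the kernel.

// Minimal host-side sketch of the v2r3 sizing helpers (illustrative only).
#include <cstdint>
#include <iostream>

using index_t = int32_t;

constexpr index_t integer_divide_ceil(index_t x, index_t y) { return (x + y - 1) / y; }

int main()
{
    constexpr index_t MPerBlock = 256, NPerBlock = 128, K1 = 8; // assumed tile configuration
    const index_t M = 1000, N = 768, K = 4100;                  // arbitrary sample problem

    // Same arithmetic as CalculateMPadded / CalculateNPadded / CalculateK0 above.
    const index_t MPadded = integer_divide_ceil(M, MPerBlock) * MPerBlock; // 1024
    const index_t NPadded = integer_divide_ceil(N, NPerBlock) * NPerBlock; // 768
    const index_t K0      = integer_divide_ceil(K, K1);                    // 513; K is viewed as K0 x K1

    std::cout << "MPadded=" << MPadded << " NPadded=" << NPadded << " K0=" << K0 << '\n';
    return 0;
}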
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const CBlockClusterAdaptor c_block_cluster_adaptor) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + a_element_op, + b_element_op, + c_element_op, + c_block_cluster_adaptor); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_b_k0_m_k1_grid_desc; + ignore = b_b_k0_n_k1_grid_desc; + ignore = c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = c_block_cluster_adaptor; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + 
make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); + const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2); + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); + + if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) && + K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) && + K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) && + KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + if(!block_2_ctile_map.CheckValidity(c_m_n_grid_desc)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = K0 > K0PerBlock; + + return has_main_k0_block_loop; + } + + __host__ __device__ static constexpr auto + MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + using BlockwiseGemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + + return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_m_n_grid_desc); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( + const CMNGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) + { + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + c_m_n_grid_desc, 8, KBatch); + } + + using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); + using CBlockClusterAdaptor = 
decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + + template + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const CBlockClusterAdaptor& c_block_cluster_adaptor) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize()); + + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!c_block_cluster_adaptor.ValidCTileIndex( + make_tuple(block_work_idx[I1], block_work_idx[I2]), + make_tuple(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I0), + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetLength(I1)))) + { + return; + } + + const index_t k_batch_id = block_work_idx[I0]; + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto a_b_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto b_b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), 
+ decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = p_shared_block; + FloatAB* p_b_block = p_shared_block + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // preload data into LDS + { + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); + + a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); + } + + // Initialize C + c_thread_buf.Clear(); + + // main body + if constexpr(HasMainKBlockLoop) + { + index_t k0_block_data_begin = 0; + + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, a_block_slice_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, b_block_slice_copy_step); + + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); + + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 - K0PerBlock)); + } + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + 
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, I1, I1, Number{}, I1, Number{}, I1)); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3], + m_thread_data_on_grid_idx[I4], + n_thread_data_on_grid_idx[I2]), + c_element_op}; + + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_grid_buf); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp new file mode 100755 index 000000000000..256b495c6e98 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp @@ -0,0 +1,1130 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
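The v2r4r2 variant defined below adds an explicit split-K dimension: the launch grid is (N tiles, M tiles, k_batch), and K is padded so that each K batch owns a K0Padded x K1 slab, selected per block via k_batch_id inside Run. A minimal sketch of those host-side calculations, mirroring CalculateK0Padded / CalculateKPadded / CalculateGridSize, follows; the tile parameters and the sample problem are assumed example values.

// Minimal sketch of the split-K sizing used by the v2r4r2 kernel (illustrative only).
#include <cstdint>
#include <iostream>

using index_t = int32_t;

constexpr index_t integer_divide_ceil(index_t x, index_t y) { return (x + y - 1) / y; }

int main()
{
    constexpr index_t MPerBlock = 256, NPerBlock = 128, K0PerBlock = 4, K1 = 8; // assumed
    const index_t M = 4096, N = 4000, K = 1030, KBatch = 4;

    // CalculateK0Padded: round K up to a multiple of KBatch * K0PerBlock * K1,
    // then express each batch's K extent in units of K1 groups.
    const index_t K_t      = KBatch * K0PerBlock * K1;           // 128
    const index_t K0Padded = (K + K_t - 1) / K_t * K0PerBlock;   // ceil(1030 / 128) * 4 = 36
    const index_t KPadded  = KBatch * K0Padded * K1;             // 1152

    // CalculateGridSize: one grid z-slice per K batch.
    const index_t grid_x = integer_divide_ceil(N, NPerBlock);    // 32
    const index_t grid_y = integer_divide_ceil(M, MPerBlock);    // 16
    const index_t grid_z = KBatch;                               // 4

    std::cout << "K0Padded=" << K0Padded << " KPadded=" << KPadded
              << " grid=(" << grid_x << "," << grid_y << "," << grid_z << ")\n";
    return 0;
}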
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r4r2_simplified(typename GridwiseGemm::Argument karg, + const Block2CTileMap& b2c_map, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte(); + + __shared__ uint8_t p_shared[shared_size]; + + GridwiseGemm::template Run( + karg, static_cast(p_shared), b2c_map, a_element_op, b_element_op, c_element_op); +#else + ignore = karg; + ignore = b2c_map; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template +struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + static constexpr auto M01 = 1; + static constexpr auto N01 = 1; + + static constexpr auto gemm_padder = + tensor_operation::device::GemmPadder{ + MPerBlock, NPerBlock, K1* K0PerBlock}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + struct Argument : public ck::tensor_operation::device::BaseArgument + { + const FloatA* p_a_grid; + const FloatB* p_b_grid; + FloatC* p_c_grid; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t MPadded; + index_t NPadded; + index_t KPadded; + index_t K0Padded; + index_t k_batch; + + Argument(const FloatA* p_a_grid_, + const FloatB* p_b_grid_, + FloatC* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t MPadded_, + index_t NPadded_, + index_t KPadded_, + index_t K0Padded_, + index_t k_batch_) + : p_a_grid(p_a_grid_), + p_b_grid(p_b_grid_), + p_c_grid(p_c_grid_), + M(M_), + N(N_), + K(K_), + StrideA(StrideA_), + StrideB(StrideB_), + StrideC(StrideC_), + MPadded(MPadded_), + NPadded(NPadded_), + KPadded(KPadded_), + K0Padded(K0Padded_), + k_batch(k_batch_) + { + } + + void Print() const + { + std::cout << "arg {" + 
<< "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KP:" << KPadded << ", " + << "K0Padded:" << K0Padded << ", " + << "KB:" << k_batch << "}" << std::endl; + } + }; + + __host__ __device__ static auto CalculateGridSize(const Argument& karg) + { + return std::make_tuple(math::integer_divide_ceil(karg.N, NPerBlock), + math::integer_divide_ceil(karg.M, MPerBlock), + karg.k_batch); + } + + // prefer this to be called on host + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateK0Padded(index_t K, index_t K_Batch = 1) + { + // k_batch * k0 * k0_per_block * k1 + auto K_t = K_Batch * K0PerBlock * K1; + return (K + K_t - 1) / K_t * K0PerBlock; + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K0Padded = CalculateK0Padded(K, K_Batch); + return K_Batch * K0Padded * K1; + } + + __host__ __device__ static auto MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, + index_t MPad, + index_t K, + index_t StrideA, + index_t KBatch, + index_t K0Padded, + index_t KPad) + { + const auto a_grid_desc_m_k = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + // const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock; + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding) + { + const auto a_grid_desc_m_kpad = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + a_grid_desc_m_kpad, + 
make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + __host__ __device__ static auto MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, + index_t NPad, + index_t N, + index_t StrideB, + index_t KBatch, + index_t K0Padded, + index_t KPad) + { + const auto b_grid_desc_k_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB)); + } + }(); + + if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) + { + + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding) + { + // const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock; + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding) + { + const auto b_grid_desc_kpad_n = transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return transform_tensor_descriptor( + b_grid_desc_kpad_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + else + { + return transform_tensor_descriptor( + b_grid_desc_k_n, + make_tuple(make_unmerge_transform(make_tuple(KBatch, K0Padded, K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + } + } + + __host__ __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC) + { + const auto c_grid_desc_m_n = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) 
+ { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + return gemm_padder.PadCDescriptor_M_N(c_grid_desc_m_n); + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto c_block_size = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); + + return math::max(a_block_space_size * sizeof(LDSTypeA) + + b_block_space_size * sizeof(LDSTypeB), + c_block_size * sizeof(FloatC)); + } + + __host__ __device__ static constexpr bool CheckValidity(const Argument& karg) + { + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.k_batch * K0PerBlock * K1; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! 
K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CBlockTransferScalarPerVector_NWaveNPerXDL (" + << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CBlockTransferScalarPerVector_NWaveNPerXDL (" + << CBlockTransferScalarPerVector_NWaveNPerXDL << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + const auto num_k_loop = karg.K0Padded / K0PerBlock; + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The number of k loops (" << num_k_loop + << ") value is not supported by GridwiseGemm Pipeline." 
+ << " K0Padded: " << karg.K0Padded << ", K0PerBlock: " << K0PerBlock << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + + return true; + } + + __host__ __device__ static auto GetKPad(index_t K, index_t KBatch) + { + const index_t K0Padded = + math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock; + const index_t KPad = KBatch * K0Padded * K1; + return KPad; + } + + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0Padded) + { + const index_t num_loop = K0Padded / K0PerBlock; + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + return transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + template + __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( + const CGridDesc& c_m_n_grid_desc, index_t /* M01 */, index_t /* N01 */, index_t KBatch) + { + return BlockToCTileMap_KSplit_M00_N0_M01Adapt( + c_m_n_grid_desc, 8, KBatch); + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + } + + // return block_id to C matrix tile idx (m0, n0, k_split) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap() + { + return BlockToCTileMap_3DGrid_KSplit(); + } + + using CGridDesc_M_N = remove_cvref_t; + using DefaultBlock2CTileMap = remove_cvref_t; + + template + __device__ static void Run(const Argument& karg, + void* __restrict__ p_shared_block, + const Block2CTileMap& block_2_ctile_map, + const AElementwiseOperation a_element_op = AElementwiseOperation{}, + const BElementwiseOperation b_element_op = BElementwiseOperation{}, + const CElementwiseOperation c_element_op = CElementwiseOperation{}) + { + const FloatA* p_a_grid = karg.p_a_grid; + const FloatB* p_b_grid = karg.p_b_grid; + FloatC* p_c_grid = karg.p_c_grid; + const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1( + karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0Padded, karg.KPadded); + const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1( + karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0Padded, karg.KPadded); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // divide 
block work by [KBatch, M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I2]); + const index_t k_batch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto a_b_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto b_b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatA, + LDSTypeA, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_element_op, + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatB, + LDSTypeB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 
3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_element_op, + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + LDSTypeA, + LDSTypeB, + FloatAcc, + decltype(a_k0_m_k1_block_desc), + decltype(b_k0_n_k1_block_desc), + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + K1, + LoopSched, + ComputeTypeA, + ComputeTypeB>(); + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + auto p_a_block = reinterpret_cast(p_shared_block); + auto p_b_block = reinterpret_cast(p_a_block + a_block_space_size); + + constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // gridwise GEMM pipeline + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_b_k0_m_k1_grid_desc.GetLength(I1) * a_b_k0_m_k1_grid_desc.GetLength(I3)) / + (K0PerBlock * K1)); + + const auto gridwise_gemm_pipeline = GridwiseGemmPipe{}; + + gridwise_gemm_pipeline.template Run(a_b_k0_m_k1_grid_desc, + a_b_k0_m_k1_block_desc, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_b_k0_n_k1_grid_desc, + b_b_k0_n_k1_block_desc, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // output: register to global memory + { + constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL); + constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock = + GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared_block), + c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); 
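// A minimal standalone sketch (not part of the patch) of the split-K padding
// arithmetic implemented by GetKPad()/K0Padded further above. The concrete
// values K = 1000, K1 = 8, K0PerBlock = 4, KBatch = 4 are hypothetical and
// chosen only to make the rounding visible.
#include <cstdio>

namespace {
int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }
} // namespace

int main()
{
    const int K = 1000, K1 = 8, K0PerBlock = 4, KBatch = 4;

    // Same formula as GetKPad(K, KBatch) in the gridwise GEMM above:
    const int K0Padded = integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
    const int KPad     = KBatch * K0Padded * K1;

    // ceil(1000 / (8 * 4 * 4)) = 8, so K0Padded = 32 and KPad = 4 * 32 * 8 = 1024.
    // K0Padded is a whole multiple of K0PerBlock by construction, which matches
    // the num_loop = K0Padded / K0PerBlock division in CalculateHasMainK0BlockLoop.
    std::printf("K0Padded = %d, KPad = %d\n", K0Padded, KPad);
    return 0;
}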
+ + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_unmerge_transform(make_tuple(CShuffleMRepeatPerShuffle, + M1, + M2, + M3, + M4)), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_freeze_transform(I0), // freeze nblock + make_unmerge_transform(make_tuple(CShuffleNRepeatPerShuffle, + N1, + N2))), // M1 = MWave, M2 * M3 * M4 = MPerXDL + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerXDL, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + FloatC, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXDL, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun + {c_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMRepeatPerShuffle * MWave * MPerXDL, 
0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, CShuffleNRepeatPerShuffle * NWave * NPerXDL); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, -CShuffleNRepeatPerShuffle * NWave * NPerXDL); + + static_for<0, MRepeat, CShuffleMRepeatPerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, NRepeat, CShuffleNRepeatPerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? nxdlperwave_iter + : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run(c_block_desc_mblock_mperblock_nblock_nperblock, + c_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, mxdlperwave_forward_step); + } + }); + } + } + + static std::string GetTypeString() + { + auto str = std::stringstream(); + + // clang-format off + str << "GemmXdlSplitKCShuffle_" + << getGemmSpecializationString(GemmSpec) << "_" + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << "_" + << "B" << BlockSize << "_" + << "Vec" << ABlockTransferSrcScalarPerVector << "x" + << BBlockTransferSrcScalarPerVector << "x" + << CBlockTransferScalarPerVector_NWaveNPerXDL << "_" + << MPerBlock << "x" + << NPerBlock << "x" + << K0PerBlock << "x" + << K1 ; + // clang-format on + + return str.str(); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp new file mode 100755 index 000000000000..15c2da9d3221 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp @@ -0,0 +1,735 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v3r1( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = a_grid_desc_ak0_m_ak1; + ignore = b_grid_desc_bk0_n_bk1; + ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename FloatCShuffle, + typename FloatC, + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename AGridDesc_AK0_M_AK1, + typename BGridDesc_BK0_N_BK1, + typename CGridDesc_M_N, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + index_t AK1Value, + index_t BK1Value, + index_t MPerXdl, + index_t NPerXdl, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_AK0_M_AK1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_K1, + bool AThreadTransferSrcResetCoordinateAfterRun, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_BK0_N_BK1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t 
BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumGemmKPrefetchStage = 1, + PipelineVersion PipelineVer = PipelineVersion::v1> +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0 = Number{}; + static constexpr auto BK0 = Number{}; + static constexpr auto AK1 = Number{}; + static constexpr auto BK1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + constexpr auto max_lds_align = AK1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0, Number{}, AK1), + make_tuple(Number{} * AK1, AK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(AK0, Number{}, AK1), max_lds_align); + } + }(); + + return a_block_desc_ak0_m_ak1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + constexpr auto max_lds_align = BK1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(BK0, Number{}, BK1), + make_tuple(Number{} * BK1, BK1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(BK0, Number{}, BK1), max_lds_align); + } + }(); + + return b_block_desc_bk0_n_bk1; + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), AK1); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_bk0_n_bk1.GetElementSpaceSize(), BK1); + + // LDS allocation for C shuffle in LDS + constexpr auto 
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + constexpr auto c_block_size = + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatCShuffle)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + // static_assert(is_known_at_compile_time>::value && + // is_known_at_compile_time>::value, + // "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1); + const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1); + const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K / KPerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple( + MBlock, Number{}, Number{})), + make_unmerge_transform(make_tuple( + NBlock, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t< + decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + CGridDesc_M_N{}))>; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + 
__device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + void* __restrict__ p_shared, + const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I3)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1, BK1); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM 
definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1); + constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) || + (is_same::value && lcm_AK1_BK1 <= 8) || + ((is_same::value || is_same::value) && + lcm_AK1_BK1 < 32)) + ? true + : false; + constexpr auto is_scale_mfma = false; + constexpr index_t k_pack = math::max( + lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0); + + // gridwise GEMM pipeline + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<>{}, + Sequence<0>{}, + Sequence<2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 
m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // LDS to global + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle, + MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle, + NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, + FloatCShuffle, // typename SrcData, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, + 5, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(0, 0, 0, 0, 0, 0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0); + + static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, + NXdlPerWave, + CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? 
nxdlperwave_iter + : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_shuffle_block_buf, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + { + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp new file mode 100755 index 000000000000..e22bfb64397e --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp @@ -0,0 +1,761 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
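// A back-of-the-envelope LDS budget (not part of the patch) mirroring
// GetSharedMemoryNumberOfByte() in the v3r1 grid that closes just above.
// It assumes a 256x128x32 bf16 tile, a 128x64 C-shuffle slab, a 4-byte
// FloatCShuffle and no LDS padding; all of these numbers are hypothetical
// and exist only to show why the function takes a max() of the two phases.
#include <algorithm>
#include <cstdio>

int main()
{
    const int MPerBlock = 256, NPerBlock = 128, KPerBlock = 32;
    const int sizeof_ab       = 2; // bf16 A/B operands
    const int sizeof_cshuffle = 4; // assumed size of FloatCShuffle

    const int a_elems = KPerBlock * MPerBlock; // 8192
    const int b_elems = KPerBlock * NPerBlock; // 4096
    const int c_elems = 128 * 64;              // one C-shuffle slab

    const int ab_bytes = (a_elems + b_elems) * sizeof_ab;  // 24576 B
    const int c_bytes  = c_elems * sizeof_cshuffle;        // 32768 B

    // The kernel reuses the same p_shared allocation for the main loop
    // (A + B tiles) and for the epilogue (C shuffle), hence the max():
    std::printf("LDS bytes = %d\n", std::max(ab_bytes, c_bytes)); // 32768
    return 0;
}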
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v3r2( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_c0_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_c0_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename FloatC, + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename AGridDesc_K0_M_K1, + typename BGridDesc_K0_N_K1, + typename CGridDesc_M_N, + typename C0GridDesc_M_N, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CElementwiseOperation, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t MPerXdl, + index_t NPerXdl, + index_t K1Value, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t 
ABlockTransferDstScalarPerVector_K1, + bool AThreadTransferSrcResetCoordinateAfterRun, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumGemmKPrefetchStage = 1, + PipelineVersion PipelineVer = PipelineVersion::v1> +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + 
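// A small stride-arithmetic sketch (not part of the patch) of what the
// ABlockLdsExtraM switch in GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
// above changes. K0PerBlock = 4, MPerBlock = 128, K1 = 8 are hypothetical
// values used only for the numbers.
#include <cstdio>

int main()
{
    const int K0PerBlock = 4, MPerBlock = 128, K1 = 8;

    // ABlockLdsExtraM == true: explicit strides with one extra K1 group per
    // K0 row, i.e. a row pitch of (MPerBlock + 1) * K1 elements, which
    // staggers the LDS addressing pattern (typically to reduce bank conflicts).
    const int padded_k0_stride = (MPerBlock + 1) * K1; // 1032 elements
    // ABlockLdsExtraM == false: packed layout, merely aligned to K1.
    const int packed_k0_stride = MPerBlock * K1;       // 1024 elements

    std::printf("A-tile LDS K0 pitch: padded=%d, packed=%d elements\n",
                padded_k0_stride, packed_k0_stride);
    // The padded layout costs K0PerBlock * K1 = 32 extra elements of LDS per
    // A tile in exchange for the staggered addressing.
    return 0;
}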
+ constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + constexpr auto c_block_size = + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / (K0PerBlock * K1); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + const CGridDesc_M_N_& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple( + MBlock, Number{}, Number{})), + make_unmerge_transform(make_tuple( + NBlock, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{})); + + return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + + using 
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t< + decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + CGridDesc_M_N{}))>; + + using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t< + decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + C0GridDesc_M_N{}))>; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + auto c0_grid_buf = make_dynamic_buffer( + p_c0_grid, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I3)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + a_grid_desc_k0_m_k1, + make_multi_index(0, 
m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + NumGemmKPrefetchStage>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_k0_n_k1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_tuple( + make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) per shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<>{}, + Sequence<0>{}, + Sequence<2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<3, 7>{}) + + ); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 
m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r2< + ThisThreadBlock, // index_t BlockSize, + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle, + MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle, + NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, + FloatC, // typename Src0Data, + FloatC, // typename Src1Data, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, + 5, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, + true, // bool ThreadTransferSrc0ResetCoordinateAfterRun, + false, // bool ThreadTransferSrc1ResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(0, 0, 0, 0, 0, 0), + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0); + + static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, + NXdlPerWave, + CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? 
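// Illustrative sketch of the traversal being set up here: the (mxdlperwave, nxdlperwave)
// shuffle tiles are visited in serpentine order -- forward over N on even M steps,
// backward on odd ones -- so the c0/c slice windows only ever move by a single step
// between tiles. Standalone reproduction at tile-step granularity (step size assumed 1):
#include <utility>
#include <vector>

inline std::vector<std::pair<int, int>> serpentine_order(int m_steps, int n_steps)
{
    std::vector<std::pair<int, int>> order;
    for(int m = 0; m < m_steps; ++m)
    {
        const bool forward = (m % 2 == 0);
        for(int i = 0; i < n_steps; ++i)
        {
            order.emplace_back(m, forward ? i : n_steps - 1 - i);
        }
    }
    return order;
}
// serpentine_order(2, 3) -> (0,0) (0,1) (0,2) (1,2) (1,1) (1,0):
// consecutive tiles always differ by one step in exactly one dimension.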
nxdlperwave_iter + : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_block_buf, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_buf, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp new file mode 100755 index 000000000000..3da5e66018fd --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp @@ -0,0 +1,799 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
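// The header below defines a gridwise XDLOPS GEMM whose epilogue reads two auxiliary
// grids (p_c0_grid, p_c1_grid) next to the shuffled accumulator and writes C through a
// three-source ThreadGroupTensorSliceTransfer_v6r3. The element-wise op is supplied by
// the caller; the functor below is only a hypothetical example with the matching call
// shape (e.g. fusing a bias and a residual add), not the op actually instantiated here.
struct ExampleAddAdd
{
    template <typename Y, typename X0, typename X1, typename X2>
    __host__ __device__ constexpr void
    operator()(Y& y, const X0& acc, const X1& c0, const X2& c1) const
    {
        // y = GEMM accumulator + first auxiliary tensor + second auxiliary tensor
        y = static_cast<Y>(acc + c0 + c1);
    }
};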
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v3r3( + const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const FloatC* __restrict__ p_c1_grid, + const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CElementwiseOperation c_element_op, + const Block2CTileMap block_2_ctile_map) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ + defined(__gfx94__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_c_grid, + p_c0_grid, + p_c1_grid, + p_shared, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + a_element_op, + b_element_op, + c_element_op, + block_2_ctile_map); +#else + ignore = p_a_grid; + ignore = p_b_grid; + ignore = p_c_grid; + ignore = p_c0_grid; + ignore = p_c1_grid; + ignore = a_grid_desc_k0_m_k1; + ignore = b_grid_desc_k0_n_k1; + ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; + ignore = block_2_ctile_map; +#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) +} + +template < + index_t BlockSize, + typename FloatAB, + typename FloatAcc, + typename FloatC, + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename AGridDesc_K0_M_K1, + typename BGridDesc_K0_N_K1, + typename CGridDesc_M_N, + typename C0GridDesc_M_N, + typename C1GridDesc_M_N, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename 
CElementwiseOperation, + index_t MPerBlock, + index_t NPerBlock, + index_t K0PerBlock, + index_t MPerXdl, + index_t NPerXdl, + index_t K1Value, + index_t MXdlPerWave, + index_t NXdlPerWave, + typename ABlockTransferThreadClusterLengths_K0_M_K1, + typename ABlockTransferThreadClusterArrangeOrder, + typename ABlockTransferSrcAccessOrder, + index_t ABlockTransferSrcVectorDim, + index_t ABlockTransferSrcScalarPerVector, + index_t ABlockTransferDstScalarPerVector_K1, + bool AThreadTransferSrcResetCoordinateAfterRun, + bool ABlockLdsExtraM, + typename BBlockTransferThreadClusterLengths_K0_N_K1, + typename BBlockTransferThreadClusterArrangeOrder, + typename BBlockTransferSrcAccessOrder, + index_t BBlockTransferSrcVectorDim, + index_t BBlockTransferSrcScalarPerVector, + index_t BBlockTransferDstScalarPerVector_K1, + bool BThreadTransferSrcResetCoordinateAfterRun, + bool BBlockLdsExtraN, + index_t CShuffleMXdlPerWavePerShuffle, + index_t CShuffleNXdlPerWavePerShuffle, + typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + index_t CBlockTransferScalarPerVector_NWaveNPerXdl, + index_t NumGemmKPrefetchStage = 1, + PipelineVersion PipelineVer = PipelineVersion::v1> +struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + using ThisThreadBlock = ThisThreadBlock; + + using GridwiseGemmPipe = remove_cvref_t< + decltype(GridwiseGemmPipeline_Selector())>; + + __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return a_block_desc_k0_m_k1; + } + + __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1() + { + constexpr auto max_lds_align = K1; + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + return b_block_desc_k0_n_k1; + } + + __host__ __device__ static constexpr auto + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + Number{}, + I1, + Number{}, + Number{})); + + return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + __host__ __device__ static constexpr index_t 
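// Illustrative sketch for the ABlockLdsExtraM / BBlockLdsExtraN branches above: the
// padded descriptor widens each K0 slice by one K1-vector (the usual CK choice of a
// (MPerBlock + 1) * K1 pitch -- treated here as an assumption about the instantiated
// constant), which skews addresses so that column-wise LDS reads do not all hit the
// same bank.
constexpr long packed_pitch(long m_per_block, long k1) { return m_per_block * k1; }
constexpr long padded_pitch(long m_per_block, long k1) { return (m_per_block + 1) * k1; }

// MPerBlock = 128, K1 = 8: 1024 vs 1032 elements per K0 slice,
// i.e. exactly one spare K1-vector of padding per slice.
static_assert(padded_pitch(128, 8) - packed_pitch(128, 8) == 8, "");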
GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + constexpr auto max_lds_align = K1; + + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = + math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + constexpr auto c_block_size = + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * + sizeof(FloatAB), + c_block_size * sizeof(FloatC)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ __device__ static constexpr bool + CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDesc_M_N& c_grid_desc_m_n, + const Block2CTileMap& block_2_ctile_map) + { + static_assert(is_known_at_compile_time>::value, + "wrong! K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + const auto M = a_grid_desc_k0_m_k1.GetLength(I1); + const auto N = b_grid_desc_k0_n_k1.GetLength(I1); + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && + K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) && + K1 == b_grid_desc_k0_n_k1.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) + return false; + + // check gridwise gemm pipeline + const auto num_k_loop = K0 / K0PerBlock; + + if(!GridwiseGemmPipe::IsSupported(num_k_loop)) + { + return false; + } + + if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) + { + return false; + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / (K0PerBlock * K1); + + return GridwiseGemmPipe::CalculateHasMainLoop(num_loop); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + const CGridDesc_M_N_& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto MBlock = M / MPerBlock; + const auto NBlock = N / NPerBlock; + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple( + MBlock, Number{}, Number{})), + make_unmerge_transform(make_tuple( + NBlock, Number{}, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 2>{}, 
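// Illustrative arithmetic for GetSharedMemoryNumberOfByte() above: the A/B tiles (main
// loop) and the C shuffle slab (epilogue) reuse the same __shared__ allocation, so the
// requirement is the max of the two phases rather than their sum. All sizes below are
// hypothetical (bf16 A/B and bf16 C assumed).
constexpr long lds_bytes(long a_elems, long b_elems, long c_elems,
                         long ab_elem_bytes, long c_elem_bytes)
{
    return (a_elems + b_elems) * ab_elem_bytes > c_elems * c_elem_bytes
               ? (a_elems + b_elems) * ab_elem_bytes
               : c_elems * c_elem_bytes;
}
// A/B tiles of 8*128*8 elements each, one C shuffle slab of (2*64) x (2*64) elements:
// max(2 * 8192 * 2, 16384 * 2) = 32768 bytes of LDS.
static_assert(lds_bytes(8 * 128 * 8, 8 * 128 * 8, (2 * 64) * (2 * 64), 2, 2) == 32768, "");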
Sequence<3, 4, 5>{})); + + return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap( + const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */) + { + return BlockToCTileMap_M00_N0_M01Adapt( + c_grid_desc_m_n); + } + using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t< + decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + CGridDesc_M_N{}))>; + + using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t< + decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + C0GridDesc_M_N{}))>; + + using C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl = + remove_cvref_t< + decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl( + C1GridDesc_M_N{}))>; + + using DefaultBlock2CTileMap = + remove_cvref_t; + + template + __device__ static void + Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const FloatC* __restrict__ p_c0_grid, + const FloatC* __restrict__ p_c1_grid, + void* __restrict__ p_shared, + const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1, + const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1, + const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl& + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + const AElementwiseOperation& a_element_op, + const BElementwiseOperation& b_element_op, + const CElementwiseOperation& c_element_op, + const Block2CTileMap& block_2_ctile_map) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + auto c0_grid_buf = make_dynamic_buffer( + p_c0_grid, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + auto c1_grid_buf = make_dynamic_buffer( + p_c1_grid, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0); + + // divide block work by [M, N] + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetLength(I3)))) + { + return; + } + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t 
n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_grid_desc_k0_m_k1), + decltype(a_block_desc_k0_m_k1), + ABlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_grid_desc_k0_m_k1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_k0_m_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_grid_desc_k0_n_k1), + decltype(b_block_desc_k0_n_k1), + BBlockTransferSrcAccessOrder, + Sequence<1, 0, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_k0_n_k1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_k0_n_k1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = + math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + static_cast(p_shared) + a_block_space_size_aligned, + b_block_desc_k0_n_k1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + + // gridwise GEMM pipeline + const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock); + + GridwiseGemmPipe::template Run(a_grid_desc_k0_m_k1, + a_block_desc_k0_m_k1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_k0_n_k1, + b_block_desc_k0_n_k1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + blockwise_gemm, + c_thread_buf, + K0BlockMainLoop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave 
= NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl = + GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(); + + auto c_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl + .GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_tuple(make_freeze_transform(I0), // freeze mblock + make_pass_through_transform( + Number{}), // M0 (MXdlPerWave) per + // shuffle + make_unmerge_transform( + make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_freeze_transform(I0), // freeze nblock + make_pass_through_transform( + Number{}), // N0 (NXdlPerWave) per + // shuffle + make_unmerge_transform( + make_tuple(N1, N2))), // M1 = MWave, M2 * M3 * M4 = MPerXdl + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<>{}, + Sequence<0>{}, + Sequence<2, 4, 5, 6>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<3, 7>{}) + + ); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + 
InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r3< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle, + MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle, + NWave * NPerXdl>, // BlockSliceLengths, + CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl, + Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder, + FloatC, // typename Src0Data, + FloatC, // typename Src1Data, + FloatC, // typename Src2Data, + FloatC, // typename DstData, + decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype(c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl), + Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder, + 5, // index_t VectorDim, + CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector, + true, // bool ThreadTransferSrc0ResetCoordinateAfterRun, + false, // bool ThreadTransferSrc1ResetCoordinateAfterRun, + false, // bool ThreadTransferSrc2ResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(0, 0, 0, 0, 0, 0), + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0), + c_element_op}; + + constexpr auto mxdlperwave_forward_step = + make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0); + constexpr auto nxdlperwave_forward_step = + make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0); + constexpr auto nxdlperwave_backward_step = + make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0); + + static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) { + constexpr auto mxdlperwave = mxdlperwave_iter; + + static_for<0, + NXdlPerWave, + CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) { + constexpr bool nxdlperwave_forward_sweep = + (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0); + + constexpr index_t nxdlperwave_value = + nxdlperwave_forward_sweep + ? 
nxdlperwave_iter + : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle); + + constexpr auto nxdlperwave = Number{}; + + // make sure it's safe to do ds_write + block_sync_lds(); + + // VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_buf); + + // make sure it's safe to do ds_read + block_sync_lds(); + + // LDS to global + c_block_copy_lds_to_global.Run( + c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_block_buf, + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c0_grid_buf, + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c1_grid_buf, + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + c_grid_buf); + + // move on nxdlperwave dimension + if constexpr(nxdlperwave_forward_sweep && + (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveSrc2SliceWindow( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_forward_step); + } + else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0)) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + + c_block_copy_lds_to_global.MoveSrc2SliceWindow( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + nxdlperwave_backward_step); + } + }); + + // move on mxdlperwave dimension + if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle) + { + c_block_copy_lds_to_global.MoveSrc1SliceWindow( + c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveSrc2SliceWindow( + c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + + c_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, + mxdlperwave_forward_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp new file mode 100755 index 000000000000..3d5066d52d34 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -0,0 +1,2596 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
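// The kernels in this header consume routing metadata prepared on the host: token ids
// sorted by expert (padded so every expert's range is a whole number of MPerBlock-sized
// tiles), one expert id per tile, and the padded token count. The exact encoding is
// produced outside this file; the sketch below only illustrates the assumed grouping,
// with invalid slots filled by a caller-chosen sentinel. All names are hypothetical.
#include <cstddef>
#include <cstdint>
#include <vector>

struct MoeRoutingSketch
{
    std::vector<std::int32_t> sorted_token_ids;  // one (padded) token slot per row of A
    std::vector<std::int32_t> sorted_expert_ids; // one expert id per MPerBlock tile
    std::int32_t max_token_id;                   // assumed: number of padded slots
};

inline MoeRoutingSketch
build_routing(const std::vector<std::vector<std::int32_t>>& tokens_per_expert,
              std::int32_t tile_size, // = MPerBlock
              std::int32_t pad_sentinel)
{
    MoeRoutingSketch r{};
    for(std::size_t e = 0; e < tokens_per_expert.size(); ++e)
    {
        const auto& toks = tokens_per_expert[e];
        const std::int32_t tiles =
            static_cast<std::int32_t>((toks.size() + tile_size - 1) / tile_size);
        for(std::int32_t t = 0; t < tiles; ++t)
            r.sorted_expert_ids.push_back(static_cast<std::int32_t>(e));
        for(std::int32_t i = 0; i < tiles * tile_size; ++i)
            r.sorted_token_ids.push_back(
                i < static_cast<std::int32_t>(toks.size()) ? toks[i] : pad_sentinel);
    }
    r.max_token_id = static_cast<std::int32_t>(r.sorted_token_ids.size());
    return r;
}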
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp" + +#define DEBUG_LOG 0 + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe + +enum Activation +{ + gelu_and_mul = 0, + silu_and_mul = 1 +}; + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_moe_gemm(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseMoeGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = 
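// For reference, the two Activation values above are the usual gated-FFN epilogues
// (as commonly defined, silu_and_mul(gate, up) = silu(gate) * up on the two halves of
// the up/gate projection); the actual epilogue applied by these kernels is whatever
// CElementwiseOperation the caller instantiates. Scalar reference sketch only:
#include <cmath>

inline float silu_and_mul_ref(float gate, float up)
{
    const float silu = gate / (1.0f + std::exp(-gate)); // x * sigmoid(x)
    return silu * up;
}

inline float gelu_and_mul_ref(float gate, float up)
{
    // tanh approximation of GELU (a common choice; the erf-based form also exists)
    const float gelu =
        0.5f * gate *
        (1.0f + std::tanh(0.7978845608f * (gate + 0.044715f * gate * gate * gate)));
    return gelu * up;
}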
Number<7>{}; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + static constexpr auto BlockSizeNumber = Number{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + using mfma_selector = MfmaSelector; + static constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); + static constexpr index_t KLane = + mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); + + static constexpr index_t KGroup = []() { + if constexpr(is_same_v, f8_t>) + // On gfx950, we have a mfma that required 32 f8 elements as input, + // splited into 2 groups of 16 f8 elements. + // the 2 groups is not contiguous in the B preshuffed layout. + // and we do not want it to be contiguous in the B preshuffled layout + // because a memory instruction can only read 16 f8 elements at a time. + return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1; + else + return 1; + }(); + + static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup); + + static constexpr index_t NLane = NPerXdl; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + // static constexpr index_t NumTokens = 1; + static constexpr index_t SortedTileSize = MPerBlock; + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + const index_t nblock = math::integer_divide_ceil(N, NPerBlock); + const index_t mblock = math::integer_divide_ceil(M, MPerBlock); + const index_t gridx = NSwizzle ? nblock * mblock : nblock; + const index_t gridy = NSwizzle ? 
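// Sketch at this step in CalculateGridSize: with NSwizzle the whole tile grid is
// flattened into gridDim.x (and presumably re-derived inside the kernel); otherwise x
// indexes N tiles and y indexes M tiles. Standalone arithmetic, sizes hypothetical:
constexpr long ceil_div(long a, long b) { return (a + b - 1) / b; }

constexpr long grid_x(long M, long N, long m_per_block, long n_per_block, bool n_swizzle)
{
    return n_swizzle ? ceil_div(N, n_per_block) * ceil_div(M, m_per_block)
                     : ceil_div(N, n_per_block);
}
constexpr long grid_y(long M, long m_per_block, bool n_swizzle)
{
    return n_swizzle ? 1 : ceil_div(M, m_per_block);
}
// M = 4096, N = 7168, 128 x 128 tiles:
static_assert(grid_x(4096, 7168, 128, 128, false) == 56 && grid_y(4096, 128, false) == 32, "");
static_assert(grid_x(4096, 7168, 128, 128, true) == 1792 && grid_y(4096, 128, true) == 1, "");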
1 : mblock; + + return std::make_tuple(gridx, gridy, 1); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateBN0Shuffled(index_t N) + { + return math::integer_divide_ceil(N, NLane); + } + __host__ __device__ static auto CalculateBK0Shuffled(index_t K) + { + return math::integer_divide_ceil(K, KLane * KPack / KGroup); + } + + __host__ __device__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ __device__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if 
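// Illustrative evaluation of the split-K padding helpers above; the helpers below
// mirror CalculateKPadded / CalculateAK0Padded / CalculateKRead with the tile
// parameters passed explicitly, and all numbers are hypothetical.
constexpr long k_padded_per_split(long K, long k_batch, long k_per_block)
{
    return (K + k_batch * k_per_block - 1) / (k_batch * k_per_block) * k_per_block;
}
constexpr long ak0_per_split(long K, long k_batch, long k_per_block, long ak1)
{
    return (K + k_batch * k_per_block - 1) / (k_batch * k_per_block) * (k_per_block / ak1);
}
constexpr long k_read_per_split(long K, long k_batch, long k_read_vec /* = lcm(AK1, BK1) */)
{
    return (K + k_batch * k_read_vec - 1) / (k_batch * k_read_vec) * k_read_vec;
}
// K = 1000, KBatch = 4, KPerBlock = 64, AK1 = BK1 = 8:
static_assert(k_padded_per_split(1000, 4, 64) == 256, "");
static_assert(ak0_per_split(1000, 4, 64, 8) == 32, "");
static_assert(k_read_per_split(1000, 4, 8) == 256, "");
// The first KBatch - 1 splits each consume KRead = 256 of K; the last takes 1000 - 768 = 232.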
constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) + { + constexpr index_t NkSwizzleNumber = Number{}; + return make_naive_tensor_descriptor( + make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), + make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + 
make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static auto MakeCGridDescriptor_M_N( + IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template + __host__ __device__ static auto + MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeDGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + 
return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + + struct Problem + { + __host__ __device__ Problem(index_t NumTokens_, + index_t TopK_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : NumTokens{NumTokens_}, + TopK{TopK_}, + M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)}, + BN0Shuffled{CalculateBN0Shuffled(N_)}, + BK0Shuffled{CalculateBK0Shuffled(K_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "NumTokens:" << NumTokens << ", " + << "TopK:" << TopK << ", " + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t NumTokens; + index_t TopK; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + // FOR PRESHUFFLE ONLY + index_t BN0Shuffled; + index_t BK0Shuffled; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const index_t* p_sorted_token_ids_, + const index_t* p_sorted_expert_ids_, + const index_t* p_max_token_id_, + const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t NumTokens_, + index_t TopK_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{NumTokens_, + TopK_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideDs_, + StrideC_, + k_batch_}, + p_sorted_token_ids{p_sorted_token_ids_}, + p_sorted_expert_ids{p_sorted_expert_ids_}, + p_max_token_id{p_max_token_id_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{}, + p_c_grid{p_c_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + { + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + const index_t* p_sorted_token_ids; + const index_t* p_sorted_expert_ids; + const index_t* p_max_token_id; + const ADataType* p_a_grid; + const BDataType* p_b_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ 
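// Sketch for the SplitKBatchOffset constructor that follows: each split-K slice k_id
// advances the A and B base pointers by whole KRead chunks -- a plain k_id * KRead
// element offset when K is the contiguous dimension, or k_id * KRead * stride when it
// is not (the preshuffled-B branch additionally scales by NLane). The last split covers
// whatever K remains after the first KBatch - 1 splits. General shape only; the exact
// layout-to-branch mapping is as written in the code below.
#include <cstdint>

inline std::int64_t k_split_elem_offset(std::int64_t k_id, std::int64_t k_read,
                                        std::int64_t k_stride /* 1 if K is contiguous */)
{
    return k_id * k_read * k_stride;
}

inline std::int64_t k_for_split(std::int64_t k_id, std::int64_t k_batch,
                                std::int64_t K, std::int64_t k_read)
{
    return (k_id < k_batch - 1) ? k_read : K - k_read * (k_batch - 1);
}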
SplitKBatchOffset(Argument& karg, index_t k_id) + { + if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead / APackedSize; + } + else if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + // KPack * NLane * KLane * K0 * N0 + b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize; + } + + if(k_id < karg.KBatch - 1) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return a_lds_block_desc_permuted; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 + ? 
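// Illustrative sketch of the XOR-with-modulo idea used by the permuted LDS descriptor
// above: one coordinate is XOR-folded with the other before linearising, so threads
// reading a column of the tile land on different LDS banks instead of the same one.
// Generic index math only -- the real transform permutes whole K1-vectors and is wired
// to different coordinates than shown here, so treat this as the idea, not the mapping.
#include <cstdint>

inline std::int64_t swizzled_lds_elem_index(std::int64_t k0, std::int64_t m,
                                            std::int64_t m_per_block, std::int64_t k1)
{
    const std::int64_t m_swizzled = (m ^ k0) % m_per_block; // xor-with-modulo fold
    return (k0 * m_per_block + m_swizzled) * k1; // a bijection on m for power-of-two m_per_block
}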
M0 + : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize, + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are 
controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { +#if DEBUG_LOG + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + // check gridwise gemm pipeline +#if 0 + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } +#endif + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, + // NPerBlock>; + + template + __device__ static void Run(const index_t* p_sorted_token_ids, + const index_t* p_sorted_expert_ids, + const index_t* p_max_token_id, + const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + IsInputGemm ? 
problem.NumTokens : problem.NumTokens * problem.TopK, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, + problem.MPadded, + problem.N, + problem.NPadded, + problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged"); + const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; + if(expert_block_id * MPerBlock >= max_token_id) + return; + const index_t expert_id = + __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); + const auto block_mn = [&]() -> std::pair { + if constexpr(NSwizzle) + { + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = + ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( + bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); + const index_t mid = + __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); + return {nid, mid}; + } + else + { + return {blockIdx.x, blockIdx.y}; + } + }(); + + const index_t block_n_id = block_mn.first; + const index_t block_m_id = block_mn.second; + const index_t token0 = + __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); + + // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); + constexpr auto AKThreads = AK0Threads * AK1Threads; + constexpr auto AMRepeats = MPerBlock / AMThreads; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + + if(token_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray gather_offsets; + static_for<0, AMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + gather_offsets(m0) = static_cast(token_offset) * problem.K; + }); + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 
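+                                           // the fused input gemm stores gate and up weights
+                                           // back-to-back per expert, so the per-expert B stride
+                                           // doubles (the up half is addressed at
+                                           // +expert_stride/2 below)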
2 : 1)); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + expert_id * expert_stride / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + // dummy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + IndexType, + 1, + BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}, + gather_offsets); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + StaticBufferTupleOfVector + c_thread_buf_fp32; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, 
Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + c_thread_buf_up, + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + } + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + // mul scales + const float* p_sorted_weights_0 = p_ds_grid[I0]; + const float* p_scale_b = p_ds_grid[I1]; + + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); + static_assert(M4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + + if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr) + { + if constexpr(PerTokenQuant) + { + constexpr index_t scale_stride = (IsInputGemm ? 
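+                                                      // gate and up scales are likewise stored
+                                                      // per expert, so the scale stride doubles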
2 : 1); + p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock + + get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl; + } + else + { + p_scale_b += expert_id; + } + + vector_type scale_token_ids; + vector_type topk_weights; + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(PerTokenQuant) + { + scale_token_ids = + *c_style_pointer_cast*>( + p_sorted_token_ids + m_pos); + } + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + float scale_a = [&]() { + if constexpr(PerTokenQuant) + { + index_t fused_token = scale_token_ids.AsType()[m4]; + const index_t token_offset = fused_token & 0xffffff; + return token_offset < problem.NumTokens + ? p_sorted_weights_0[IsInputGemm + ? token_offset + : token_offset * + problem.TopK + + (fused_token >> + 24)] + : 0.0; + } + else + { + return p_sorted_weights_0[0]; + } + }(); + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = + scale_a * scale_b * c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) * + topk_weights.AsType()[m4]; + } + } + }); + }); + }); + }); + } + else + { + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx 
= Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = topk_weights.AsType()[m4] * + c_thread_buf_fp32[cidx]; + } + } + }); + }); + }); + }); + } + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + 
n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = + container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_m_id, 0, block_n_id, 0); + // return make_multi_index(block_work_idx[I0], 0, + // block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferCluster = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr index_t scatter_weight_idx = 3; // hack fix felix + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // 
space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = + CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = + CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + StaticallyIndexedArray scatter_offsets; + + auto dstidx = sfc_cde_block.GetIndex(access_id); + const index_t c_token_pos = + block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + IndexType token_offset = fused_token & 0xffffff; + if constexpr(IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scatter_offsets(m0) = static_cast(token_offset) * problem.N; + }); + + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf_fp32, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf), + scatter_offsets); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const index_t* p_sorted_token_ids, + const index_t* p_sorted_expert_ids, + const index_t* p_max_token_id, + const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + void* p_shared1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + IsInputGemm ? 
problem.NumTokens * problem.TopK : problem.NumTokens, + problem.MPadded, + problem.N, + problem.NPadded, + problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged"); + const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; + if(expert_block_id * MPerBlock >= max_token_id) + return; + const index_t expert_id = + __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); + const auto block_mn = [&]() -> std::pair { + if constexpr(NSwizzle) + { + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = + ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( + bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); + const index_t mid = + __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); + return {nid, mid}; + } + else + { + return {blockIdx.x, blockIdx.y}; + } + }(); + + const index_t block_n_id = block_mn.first; + const index_t block_m_id = block_mn.second; + const index_t token0 = + __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); + + // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); + constexpr auto AKThreads = AK0Threads * AK1Threads; + constexpr auto AMRepeats = MPerBlock / AMThreads; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + + if(token_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray gather_offsets; + static_for<0, AMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + gather_offsets(m0) = static_cast(token_offset) * problem.K; + }); + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 
2 : 1)); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + expert_id * expert_stride / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + // dummy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + IndexType, + 1, + 2>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}, + gather_offsets); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf_ping = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + StaticBufferTupleOfVector + c_thread_buf_fp32; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const auto 
b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + c_thread_buf_up, + num_k_block_main_loop); + } + else + { + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + } + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + // mul scales + const float* p_sorted_weights_0 = p_ds_grid[I0]; + const float* p_scale_b = p_ds_grid[I1]; + + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); + static_assert(M4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + + if(p_sorted_weights_0 != nullptr && p_scale_b != nullptr) + { + if constexpr(PerTokenQuant) + { + constexpr index_t scale_stride = (IsInputGemm ? 
2 : 1); + p_scale_b += expert_id * problem.N * scale_stride + block_n_id * NPerBlock + + get_warp_local_1d_id() % NWave * NPerXdl + threadIdx.x % NPerXdl; + } + else + { + p_scale_b += expert_id; + } + + vector_type scale_token_ids; + vector_type topk_weights; + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + const float scale_b = p_scale_b[n0 * NWave * NPerXdl * PerTokenQuant]; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(PerTokenQuant) + { + scale_token_ids = + *c_style_pointer_cast*>( + p_sorted_token_ids + m_pos); + } + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + float scale_a = [&]() { + if constexpr(PerTokenQuant) + { + index_t fused_token = scale_token_ids.AsType()[m4]; + const index_t token_offset = fused_token & 0xffffff; + return token_offset < problem.NumTokens + ? p_sorted_weights_0[IsInputGemm + ? token_offset + : token_offset * + problem.TopK + + (fused_token >> + 24)] + : 0.0; + } + else + { + return p_sorted_weights_0[0]; + } + }(); + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + const float scale_up = + p_scale_b[(n0 * NWave * NPerXdl + problem.N) * + PerTokenQuant]; + float gate = scale_a * scale_b * c_thread_buf[cidx]; + float up = scale_a * scale_up * c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = + scale_a * scale_b * c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) * + topk_weights.AsType()[m4]; + } + } + }); + }); + }); + }); + } + else + { + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx 
= Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = topk_weights.AsType()[m4] * + c_thread_buf_fp32[cidx]; + } + } + }); + }); + }); + }); + } + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + 
n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = + container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_m_id, 0, block_n_id, 0); + // return make_multi_index(block_work_idx[I0], 0, + // block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferCluster = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr index_t scatter_weight_idx = 3; // hack fix felix + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // 
space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = + CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = + CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + StaticallyIndexedArray scatter_offsets; + + auto dstidx = sfc_cde_block.GetIndex(access_id); + const index_t c_token_pos = + block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + IndexType token_offset = fused_token & 0xffffff; + if constexpr(IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scatter_offsets(m0) = static_cast(token_offset) * problem.N; + }); + + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf_fp32, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf), + scatter_offsets); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp new file mode 100755 index 000000000000..f092c9c1eb65 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -0,0 +1,2668 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
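+// Block-scale variant of the gridwise MoE gemm above: the kernel entry points
+// additionally take per-block A/B scale grids (p_a_scale_grid / p_b_scale_grid),
+// which are passed straight through to GridwiseGemm::Run / Run_2Lds alongside
+// the A and B tiles.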
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp" + +#define DEBUG_LOG 0 + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe + +enum Activation +{ + gelu_and_mul = 0, + silu_and_mul = 1 +}; + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_moe_gemm(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + karg.p_a_scale_grid, + karg.p_b_scale_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseMoeGemmBlockScale +{ + using AScaleType = float; + using BScaleType = float; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = 
Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + static constexpr auto BlockSizeNumber = Number{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + using mfma_selector = MfmaSelector; + static constexpr index_t KPack = + math::max(math::lcm(AK1Number, BK1Number), mfma_selector::selected_mfma.k_per_blk); + static constexpr index_t KGroup = []() { + if constexpr(is_same_v, f8_t>) + // On gfx950, we have a mfma that required 32 f8 elements as input, + // splited into 2 groups of 16 f8 elements. + // the 2 groups is not contiguous in the B preshuffed layout. + // and we do not want it to be contiguous in the B preshuffled layout + // because a memory instruction can only read 16 f8 elements at a time. + return mfma_selector::selected_mfma.k_per_blk == 32 ? 2 : 1; + else + return 1; + }(); + static constexpr index_t KLane = + mfma_selector::GetKPerXdlops() / mfma_selector::GetK1PerXdlops(); + static constexpr index_t KRepeat = KPerBlock / KLane / (KPack / KGroup); + static constexpr index_t NLane = NPerXdl; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + // static constexpr index_t NumTokens = 1; + static constexpr index_t SortedTileSize = MPerBlock; + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + const index_t nblock = math::integer_divide_ceil(N, NPerBlock); + const index_t mblock = math::integer_divide_ceil(M, MPerBlock); + const index_t gridx = NSwizzle ? nblock * mblock : nblock; + const index_t gridy = NSwizzle ? 
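+                                          // with NSwizzle the M and N block ids are folded into
+                                          // gridDim.x (nblock * mblock) and decoded in Run(), so
+                                          // gridDim.y collapses to 1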
1 : mblock; + return std::make_tuple(gridx, gridy, 1); + } + + __host__ __device__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ __device__ static auto CalculateBN0Shuffled(index_t N) + { + return math::integer_divide_ceil(N, NLane); + } + __host__ __device__ static auto CalculateBK0Shuffled(index_t K) + { + return math::integer_divide_ceil(K, KLane * KPack / KGroup); + } + + __host__ __device__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ __device__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ __device__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if 
constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) + { + constexpr index_t NkSwizzleNumber = Number{}; + return make_naive_tensor_descriptor( + make_tuple(N0 / NWave, NWave, K0, NkSwizzleNumber), + make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1)); + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + 
make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static auto MakeCGridDescriptor_M_N( + IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template + __host__ __device__ static auto + MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeDGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + 
return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + + using DsGridDesc_M_N = remove_cvref_t; + + struct Problem + { + __host__ __device__ Problem(index_t NumTokens_, + index_t TopK_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : NumTokens{NumTokens_}, + TopK{TopK_}, + M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)}, + BN0Shuffled{CalculateBN0Shuffled(N_)}, + BK0Shuffled{CalculateBK0Shuffled(K_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "NumTokens:" << NumTokens << ", " + << "TopK:" << TopK << ", " + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t NumTokens; + index_t TopK; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + std::array StrideDs; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + // FOR PRESHUFFLE ONLY + index_t BN0Shuffled; + index_t BK0Shuffled; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const index_t* p_sorted_token_ids_, + const index_t* p_sorted_expert_ids_, + const index_t* p_max_token_id_, + const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t NumTokens_, + index_t TopK_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + std::array StrideDs_, + index_t StrideC_, + const AScaleType* p_a_scale_grid_, + const BScaleType* p_b_scale_grid_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{NumTokens_, + TopK_, + M_, + N_, + K_, + StrideA_, + StrideB_, + StrideDs_, + StrideC_, + k_batch_}, + p_sorted_token_ids{p_sorted_token_ids_}, + p_sorted_expert_ids{p_sorted_expert_ids_}, + p_max_token_id{p_max_token_id_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_ds_grid{}, + p_c_grid{p_c_grid_}, + p_a_scale_grid{p_a_scale_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + { + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + const index_t* p_sorted_token_ids; + const index_t* p_sorted_expert_ids; + const index_t* p_max_token_id; + const ADataType* p_a_grid; + const BDataType* p_b_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + 
const AScaleType* p_a_scale_grid; + const BScaleType* p_b_scale_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) + { + if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead / APackedSize; + } + else if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + // KPack * NLane * KLane * K0 * N0 + b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize; + } + + if(k_id < karg.KBatch - 1) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return a_lds_block_desc_permuted; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(LDSTypeA)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, I1, Number{}, Number{})); + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(LDSTypeA) / APackedSize, + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are 
controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { +#if DEBUG_LOG + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { +#if DEBUG_LOG + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { +#if DEBUG_LOG + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { +#if DEBUG_LOG + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + +#endif // DEBUG_LOG + return false; + } + } + + // check gridwise gemm pipeline +#if 0 + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } +#endif + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, + // NPerBlock>; + + template + __device__ static void Run(const index_t* p_sorted_token_ids, + const index_t* p_sorted_expert_ids, + const index_t* p_max_token_id, + const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + const AScaleType* p_a_scale_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + IsInputGemm ? 
problem.NumTokens : problem.NumTokens * problem.TopK, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, + problem.MPadded, + problem.N, + problem.NPadded, + problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(IsInputGemm ? problem.NumTokens + : problem.NumTokens * problem.TopK, + ScaleBlockM), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged"); + const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; + if(expert_block_id * MPerBlock >= max_token_id) + return; + const index_t expert_id = + __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); + const auto block_mn = [&]() -> std::pair { + if constexpr(NSwizzle) + { + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = + ecnt > 0 ? 
ecnt : 1; // p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( + bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); + const index_t mid = + __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); + return {nid, mid}; + } + else + { + return {blockIdx.x, blockIdx.y}; + } + }(); + const index_t block_n_id = block_mn.first; + const index_t block_m_id = block_mn.second; + const index_t token0 = + __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); + + // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); + constexpr auto AKThreads = AK0Threads * AK1Threads; + constexpr auto AMRepeats = MPerBlock / AMThreads; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + + if(token_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray gather_offsets; + static_for<0, AMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + gather_offsets(m0) = static_cast(token_offset) * problem.K; + }); + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); + const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane( + math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 
2 : 1) * + math::integer_divide_ceil(problem.K, ScaleBlockK)); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + expert_id * static_cast(expert_stride) / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + // dummy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + IndexType, + 1, + BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}, + gather_offsets); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + constexpr index_t ScaleSliceSizeM = MXdlPerWave; + constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN); + constexpr index_t 
ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); + + // ScaleSliceSizeK is last dimension in A/B scale for vector memory access + // ScaleSliceSizeK is first dimension in C scale for packed math + constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + auto a_thread_offset = + get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, Number{})); + + // get each thread's offset in the scale tensor + // A scale + const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM; + + if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray scale_gather_offsets; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { + const index_t fused_token = + p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scale_gather_offsets(m0) = + token_offset * math::integer_divide_ceil(problem.K, ScaleBlockK); + }); + + auto a_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2_gather, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false, + MXdlPerWave>( + a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets); + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); + + // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1); + constexpr auto a_scale_thread_slice_copy_step = + make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK)); + constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK); + + constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock); + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * static_cast(expert_stride) / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + const BScaleType* p_b_scale_grid_up = + p_b_scale_grid + expert_scale_stride / 2 / BPackedSize; + const auto b_scale_grid_buf_up = make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + auto b_scale_thread_copy_up = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); + + 
blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_buf, + b_block_slice_copy_step, + + c_scale_thread_desc, + c_thread_buf, + c_thread_buf_up, + + a_scale_grid_desc_am_ak, + a_scale_thread_desc, + a_scale_thread_copy, + a_scale_grid_buf, + a_scale_thread_slice_copy_step, + + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_thread_copy_up, + b_scale_grid_buf, + b_scale_grid_buf_up, + b_scale_thread_slice_copy_step, + + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + + c_scale_thread_desc, + c_thread_buf, + + a_scale_grid_desc_am_ak, + a_scale_thread_desc, + a_scale_thread_copy, + a_scale_grid_buf, + a_scale_thread_slice_copy_step, + + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + b_scale_thread_slice_copy_step, + + num_k_block_main_loop); + } + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // transposed XDL + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock); + static_assert(M0 * M1 * M2 == MPerBlock); + static_assert(N4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m2 = threadIdx.x % get_warp_size() % M2; + + float topk_weight; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + if constexpr(MulRoutedWeight) + { + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2; + topk_weight = p_ds_grid[I0][m_pos]; + } + static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk + static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, n2 * N4 + n4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion, elementwise + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weight; + up = up * topk_weight; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weight; + up = up * topk_weight; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf(cidx) = gate * up; + } + } + else + { + if constexpr(MulRoutedWeight) + { + c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight; + } + } + }); + }); + }); + }); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2)), // M2 = MPerXdl + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4>{}, 
Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + const DDataType* ptr_ = p_ds_grid[i]; + // hack logic here to support different kind of strides. todo fix it. 
+ // ascale t, 1; bscale E, N, 1, move ptr to E + return make_dynamic_buffer( + ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = + container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_m_id, 0, block_n_id, 0); + // return make_multi_index(block_work_idx[I0], 0, + // block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferCluster = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1; // hack fix felix + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = + 
CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = + CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + StaticallyIndexedArray scatter_offsets; + + auto dstidx = sfc_cde_block.GetIndex(access_id); + const index_t c_token_pos = + block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scatter_offsets(m0) = token_offset * problem.N; + }); + + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf), + scatter_offsets); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const index_t* p_sorted_token_ids, + const index_t* p_sorted_expert_ids, + const index_t* p_max_token_id, + const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + const AScaleType* p_a_scale_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + void* p_shared1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, + problem.MPadded, + problem.N, + problem.NPadded, + problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(IsInputGemm ? 
problem.NumTokens + : problem.NumTokens * problem.TopK, + ScaleBlockM), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; + if(expert_block_id * MPerBlock >= max_token_id) + return; + const index_t expert_id = + __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); + const auto block_mn = [&]() -> std::pair { + if constexpr(NSwizzle) + { + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = ecnt > 0 ? ecnt : 1; + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( + bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); + const index_t mid = + __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); + return {nid, mid}; + } + else + { + return {blockIdx.x, blockIdx.y}; + } + }(); + const index_t block_n_id = block_mn.first; + const index_t block_m_id = block_mn.second; + + const index_t token0 = + __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); + + // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); + constexpr auto AKThreads = AK0Threads * AK1Threads; + constexpr auto AMRepeats = MPerBlock / AMThreads; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + + if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || + token0 >= problem.NumTokens) + return; + StaticallyIndexedArray + gather_offsets; //= p_sorted_token_ids[token_pos]; + static_for<0, AMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + gather_offsets(m0) = static_cast(token_offset) * problem.K; + }); + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); + const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane( + math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 
2 : 1) * + math::integer_divide_ceil(problem.K, ScaleBlockK)); + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + expert_id * static_cast(expert_stride) / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + // dummy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + IndexType, + 1, + BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}, + gather_offsets); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf_ping = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + 
decltype(c_thread_buf) c_thread_buf_up; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // scale + constexpr index_t ScaleSliceSizeM = MXdlPerWave; + constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN); + constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); + + // ScaleSliceSizeK is last dimension in A/B scale for vector memory access + // ScaleSliceSizeK is first dimension in C scale for packed math + constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + auto a_thread_offset = + get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, Number{})); + + // get each thread's offset in the scale tensor + // A scale + const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM; + + if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray scale_gather_offsets; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { + const index_t fused_token = + p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scale_gather_offsets(m0) = static_cast(token_offset) * + math::integer_divide_ceil(problem.K, ScaleBlockK); + }); + + auto a_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2_gather, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false, + MXdlPerWave>( + a_scale_grid_desc_am_ak, make_multi_index(0, 0), scale_gather_offsets); + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); + + // constexpr auto a_scale_thread_slice_copy_step = make_multi_index(0, 1); + constexpr auto a_scale_thread_slice_copy_step = + make_tuple(make_multi_index(0, 0), make_multi_index(0, ScaleSliceSizeK)); + constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK); + + constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock); + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * static_cast(expert_stride) / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + const BScaleType* p_b_scale_grid_up = + p_b_scale_grid + expert_scale_stride / 2 
/ BPackedSize; + const auto b_scale_grid_buf_up = make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + auto b_scale_thread_copy_up = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0)); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_bufs, + b_block_slice_copy_step, + c_scale_thread_desc, + c_thread_buf, + c_thread_buf_up, + a_scale_grid_desc_am_ak, + a_scale_thread_desc, + a_scale_thread_copy, + a_scale_grid_buf, + a_scale_thread_slice_copy_step, + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_thread_copy_up, + b_scale_grid_buf, + b_scale_grid_buf_up, + b_scale_thread_slice_copy_step, + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_scale_thread_desc, + c_thread_buf, + a_scale_grid_desc_am_ak, + a_scale_thread_desc, + a_scale_thread_copy, + a_scale_grid_buf, + a_scale_thread_slice_copy_step, + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + b_scale_thread_slice_copy_step, + num_k_block_main_loop); + } + + // shuffle C and write out + { + + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // transposed XDL + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + // TODO: hacky, fix it! 
+ // only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); + + static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock); + static_assert(M0 * M1 * M2 == MPerBlock); + static_assert(N4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m2 = threadIdx.x % get_warp_size() % M2; + + float topk_weight; + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + if constexpr(MulRoutedWeight) + { + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2; + topk_weight = p_ds_grid[I0][m_pos]; + } + static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk + static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, n2 * N4 + n4)); + constexpr auto cidx = Number{}; + if constexpr(IsInputGemm) // gu fusion, elementwise + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weight; + up = up * topk_weight; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weight; + up = up * topk_weight; + } + if constexpr(is_same_v, pk_i4_t>) + { + gate *= 16; + up *= 16; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf(cidx) = gate * up; + } + } + else + { + if constexpr(MulRoutedWeight) + { + c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight; + } + } + + }); + }); + }); + }); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2)), // M2 = MPerXdl + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); + + // 
calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = + container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_m_id, 0, block_n_id, 0); + // return make_multi_index(block_work_idx[I0], 0, + // block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferCluster = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr index_t scatter_weight_idx = IsInputGemm ? 
1 : 1; // hack fix felix + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = + CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = + CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + StaticallyIndexedArray + scatter_offsets; //= p_sorted_token_ids[c_token_pos]; + + auto dstidx = sfc_cde_block.GetIndex(access_id); + const index_t c_token_pos = + block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scatter_offsets(m0) = token_offset * problem.N; + }); + + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf), + 
scatter_offsets); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp new file mode 100755 index 000000000000..5f8e524fb234 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp @@ -0,0 +1,2942 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp" + +#define DEBUG_LOG 0 + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe + +enum Activation +{ + gelu_and_mul = 0, + silu_and_mul = 1 +}; + +#if 0 +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} +#endif + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared_0, + p_shared_1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseMoeGemmMX +{ + using LDSTypeA = ADataType; + using LDSTypeB = BDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + static constexpr auto I8 = Number<8>{}; + static constexpr auto I9 = Number<9>{}; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = false; + static constexpr auto is_scale_mfma = true; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto MXdlPack = 2; + static constexpr auto NXdlPack = 2; + static constexpr auto KXdlPack = 2; + + //> KPack 
is at least the k_per_blk of selected mfma + // + // Should be a multiple of k_per_blk. + // TODO: Move this to blockwise pipeline base + // KPack in packed data types for pk A/B + + static constexpr index_t APackedSize = packed_size_v; + static constexpr index_t BPackedSize = packed_size_v; + + using mfma_selector = MfmaSelector; + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, mfma_selector::selected_mfma.k_per_blk / APackedSize); + + // static constexpr index_t NumTokens = 1; + static constexpr index_t SortedTileSize = MPerBlock; + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + const index_t nblock = math::integer_divide_ceil(N, NPerBlock); + const index_t mblock = math::integer_divide_ceil(M, MPerBlock); + const index_t gridx = NSwizzle ? nblock * mblock : nblock; + const index_t gridy = NSwizzle ? 1 : mblock; + + return std::make_tuple(gridx, gridy, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + constexpr auto permuted_desc = transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_xor_with_modulo_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return transform_tensor_descriptor( + permuted_desc, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + IndexType M, IndexType MPad, IndexType K, IndexType 
KPad, IndexType StrideA, IndexType AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + const auto a_grid_desc_permuted = transform_tensor_descriptor( + a_grid_desc_ak0_m_ak1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(MPad, AK0Number)), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto a_grid_desc = transform_tensor_descriptor( + a_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)), + make_pass_through_transform(MPad), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + return a_grid_desc; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + const auto a_grid_desc_permuted = transform_tensor_descriptor( + a_grid_desc_ak0_m_ak1, + make_tuple(make_pass_through_transform(K / KPerBlock), 
+ make_xor_with_modulo_transform(make_tuple(M, AK0Number)), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto a_grid_desc = transform_tensor_descriptor( + a_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)), + make_pass_through_transform(M), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_grid_desc; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + static_assert(!(is_same_v, f4x2_pk_t> && + (GemmSpec != GemmSpecialization::Default && + GemmSpec != GemmSpecialization::MPadding)), + "f4x2_pk_t does not support K padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, 
BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + const auto b_grid_desc_permuted = transform_tensor_descriptor( + b_grid_desc_bk0_n_bk1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(N, BK0Number)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto b_grid_desc = transform_tensor_descriptor( + b_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, BK0Number)), + make_pass_through_transform(N), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor( + ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor( + BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static auto MakeCGridDescriptor_M_N( + IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template + __host__ __device__ static auto + MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeDGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + + struct Problem + { + 
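+        // [Illustrative note] As constructed from Argument below, K and StrideA are stored
+        // here already divided by APackedSize, and StrideB by BPackedSize, i.e. in
+        // packed-element units rather than raw element counts. For example, with a packed
+        // A type holding two logical elements per stored element (assumption), a raw K of
+        // 4096 would be stored as 2048, and KRead / KPadded / AK0 / BK0 are all derived
+        // from that packed value.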
__host__ Problem(index_t NumTokens_, + index_t TopK_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideScaleA_, + index_t StrideB_, + index_t StrideScaleB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : NumTokens{NumTokens_}, + TopK{TopK_}, + M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideScaleA{StrideScaleA_}, + StrideB{StrideB_}, + StrideScaleB{StrideScaleB_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "NumTokens:" << NumTokens << ", " + << "TopK:" << TopK << ", " + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " + << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t NumTokens; + index_t TopK; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideScaleA; + index_t StrideB; + index_t StrideScaleB; + std::array StrideDs; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const index_t* p_sorted_token_ids_, + const index_t* p_sorted_expert_ids_, + const index_t* p_max_token_id_, + const ADataType* p_a_grid_, + const AScaleDataType* p_a_scale_grid_, + const BDataType* p_b_grid_, + const BScaleDataType* p_b_scale_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t NumTokens_, + index_t TopK_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideScaleA_, + index_t StrideB_, + index_t StrideScaleB_, + std::array StrideDs_, + index_t StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{NumTokens_, + TopK_, + M_, + N_, + K_ / APackedSize, + StrideA_ / APackedSize, + StrideScaleA_, + StrideB_ / BPackedSize, + StrideScaleB_, + StrideDs_, + StrideC_, + k_batch_}, + p_sorted_token_ids{p_sorted_token_ids_}, + p_sorted_expert_ids{p_sorted_expert_ids_}, + p_max_token_id{p_max_token_id_}, + p_a_grid{p_a_grid_}, + p_a_scale_grid{p_a_scale_grid_}, + p_b_grid{p_b_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + p_ds_grid{}, + p_c_grid{p_c_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + { + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + const index_t* p_sorted_token_ids; + const index_t* p_sorted_expert_ids; + const index_t* p_max_token_id; + const ADataType* p_a_grid; + const AScaleDataType* p_a_scale_grid; + const BDataType* p_b_grid; + const BScaleDataType* 
p_b_scale_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) + { + if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + // KPack * NLane * KLane * K0 * N0 + b_k_split_offset = k_id * karg.KRead; + } + + // Calculate A scale offset + if constexpr(is_same_v) + { + a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize); + } + else if constexpr(is_same_v) + { + a_scale_k_split_offset = + k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA; + } + + // Calculate B scale offset + if constexpr(is_same_v) + { + b_scale_k_split_offset = + k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB; + } + else if constexpr(is_same_v) + { + b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize); + } + + if(k_id < karg.KBatch - 1) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t a_scale_k_split_offset; + index_t b_scale_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // contiguous in LDS + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return a_lds_block_desc_permuted; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto WaveSize = 64; + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = WaveSize / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 
1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // contiguous in lds + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto b_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return b_lds_block_desc_permuted; + } + else // RowMajor B + { + constexpr auto WaveSize = 64; + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = WaveSize / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 
1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + 
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + if constexpr(IsInputGemm) + { + return math::max(a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType) * 2, + c_block_size * sizeof(CShuffleDataType)); + } + else + { + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + c_block_size * sizeof(CShuffleDataType)); + } + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0, + "KPerBlock should be multiple of ScaleBlockSize"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! 
" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + + return false; + } + } + } + + // check gridwise gemm pipeline +#if 0 + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } +#endif + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, + // NPerBlock>; + + using mx_scale_t = e8m0_bexp_t; + static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + + static_assert(is_same_v && + is_same_v, + "A/B ElementwiseOperation should be PassThrough as load_to_lds is used!"); + +#if 0 + template + __device__ static void Run(const index_t* p_sorted_token_ids, + const index_t* p_sorted_expert_ids, + const index_t* p_max_token_id, + const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = a_element_op; + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + IsInputGemm ? 
problem.NumTokens * problem.TopK : problem.NumTokens, + problem.MPadded, + problem.N, + problem.NPadded, + problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.M / (MXdlPack * MPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b)); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; + if(expert_block_id * MPerBlock >= max_token_id) + return; + const index_t expert_id = + __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); + + const auto block_mn = [&]() -> std::pair { + if constexpr(NSwizzle) + { + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = + ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( + bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); + const index_t mid = + __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); + return {nid, mid}; + } + else + { + return {blockIdx.x, blockIdx.y}; + } + }(); + + const index_t block_n_id = block_mn.first; + const index_t block_m_id = block_mn.second; + const index_t token0 = + __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); + + // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); + constexpr auto AKThreads = AK0Threads * AK1Threads; + constexpr auto AMRepeats = MPerBlock / AMThreads; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + + if(token_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray gather_offsets; + static_for<0, AMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + gather_offsets(m0) = static_cast(token_offset); + }); + + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); + const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane( + problem.N * (IsInputGemm ? 
2 : 1) * + math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize)); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // Gride buffer creation + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + // A, B scale buffer + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType), + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise direct to LDS copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad< + ThisThreadBlock, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + IndexType, + 1>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + gather_offsets); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + StaticBufferTupleOfVector + c_thread_buf_fp32; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // a and b scale processing + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + 
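+        // wave_idx[I0] / wave_idx[I1] are this wave's M/N coordinates within the block's
+        // wave grid; they are used below to offset the per-wave A/B scale thread copies.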
const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; + + auto a_thread_offset_m = waveId_m; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); + + // B scale load + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + if constexpr(IsInputGemm) + { + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + auto b_block_buf_up = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride, + b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad< + ThisThreadBlock, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); + + const BScaleDataType* p_b_scale_grid_up = + p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType); + const auto b_scale_grid_buf_up = make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType), + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + blockwise_gemm_pipeline.template Run( 
+ // A + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + // Gate and Up + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_buf, + b_block_buf_up, + b_block_slice_copy_step, + // C + c_thread_buf, + c_thread_buf_up, + // A scale + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + // Gate and Up scale + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_thread_copy_up, + b_scale_grid_buf, + b_scale_grid_buf_up, + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, // A + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, // B + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, // C + a_scale_grid_desc_am_ak, // A scale + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, // B scale + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); + } + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && + CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); + + // mul scales + static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock); + static_assert(M5 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; // Mwave id + const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl; + + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack + static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack + static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + + m0 * M2 * M1 * M3 * M4 * M5 + + m1 * M2 * M3 * M4 * M5 + + imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5; + if constexpr(MulRoutedWeight) + { + topk_weights = + *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == + Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + + /*float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + //up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = up;*/ + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = + topk_weights.AsType()[m5] * + c_thread_buf_fp32[cidx]; + } + } + }); + }); + }); + }); + }); + }); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = 
make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) + // per shuffle + M1, // M1 = MWave + M2, // M2 = MXdlPack + M3, // M3 * M4 * M5 = MPerXdl + M4, + M5)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) + // per shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie([&](auto i) -> const auto& // return 
type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = + container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_m_id, 0, block_n_id, 0); + // return make_multi_index(block_work_idx[I0], 0, + // block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferCluster = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr index_t scatter_weight_idx = 3; // hack fix felix + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = + CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = + CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + StaticallyIndexedArray scatter_offsets; + + auto dstidx = sfc_cde_block.GetIndex(access_id); + const 
index_t c_token_pos = + block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + IndexType token_offset = fused_token & 0xffffff; + if constexpr(IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scatter_offsets(m0) = static_cast(token_offset) * problem.N; + }); + + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf_fp32, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf), + scatter_offsets); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +#endif + + template + __device__ static void Run_2Lds(const index_t* p_sorted_token_ids, + const index_t* p_sorted_expert_ids, + const index_t* p_max_token_id, + const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = a_element_op; + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, + problem.MPadded, + problem.N, + problem.NPadded, + problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.M / (MXdlPack * MPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b)); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + const index_t expert_block_id = NSwizzle ? 
blockIdx.x / problem.NBlock : blockIdx.y; + if(expert_block_id * MPerBlock >= max_token_id) + return; + const index_t expert_id = + __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); + const auto block_mn = [&]() -> std::pair { + if constexpr(NSwizzle) + { + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = + ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( + bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); + const index_t mid = + __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); + return {nid, mid}; + } + else + { + return {blockIdx.x, blockIdx.y}; + } + }(); + + const index_t block_n_id = block_mn.first; + const index_t block_m_id = block_mn.second; + const index_t token0 = + __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); + + // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); + constexpr auto AKThreads = AK0Threads * AK1Threads; + constexpr auto AMRepeats = MPerBlock / AMThreads; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads; + + if(token_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray gather_offsets; + static_for<0, AMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + gather_offsets(m0) = static_cast(token_offset) * problem.K; + }); + + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); + const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane( + problem.N * (IsInputGemm ? 
2 : 1) * + math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize)); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // Gride buffer creation + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + // A, B scale buffer + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType), + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise direct to LDS copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad< + ThisThreadBlock, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + IndexType, + 1>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + gather_offsets); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + bit_cast(static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = make_dynamic_buffer( + bit_cast(bit_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = 
BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + StaticBufferTupleOfVector + c_thread_buf_fp32; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // a and b scale processing + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; + + auto a_thread_offset_m = waveId_m; + + // get each thread's offset int the scale tensor + const index_t token_scale_pos = block_m_id * MPerBlock; + if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens) + return; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); + + // B scale load + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride, + b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + // lds ping pong buffers for up + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + auto b_block_buf_up_ping = make_dynamic_buffer( + bit_cast(static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_up_pong = make_dynamic_buffer( + bit_cast(bit_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto b_block_bufs_up = make_tuple(b_block_buf_up_ping, b_block_buf_up_pong); + + auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad< + ThisThreadBlock, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + 
b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); + + const BScaleDataType* p_b_scale_grid_up = + p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType); + const auto b_scale_grid_buf_up = make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType), + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + blockwise_gemm_pipeline.template Run( + // A + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + // Gate and Up + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_bufs, + b_block_bufs_up, + b_block_slice_copy_step, + // C + c_thread_buf, + c_thread_buf_up, + // A scale + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + // B scale + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_thread_copy_up, + b_scale_grid_buf, + b_scale_grid_buf_up, + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, // A + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, // B + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, // C + a_scale_grid_desc_am_ak, // A scale + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, // B scale + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); + } + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && + CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); + + // mul scales + + static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock); + static_assert(M5 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl; + + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack + static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack + static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + + m0 * M2 * M1 * M3 * M4 * M5 + + m1 * M2 * M3 * M4 * M5 + + imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5; + if constexpr(MulRoutedWeight) + { + topk_weights = + *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == + Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = + topk_weights.AsType()[m5] * + c_thread_buf_fp32[cidx]; + } + } + }); + }); + }); + }); + }); + }); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + 
make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4, + M5)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) + // per shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type 
should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = + container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_m_id, 0, block_n_id, 0); + // return make_multi_index(block_work_idx[I0], 0, + // block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferCluster = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr index_t scatter_weight_idx = 3; // hack fix felix + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = + CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = + CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + StaticallyIndexedArray scatter_offsets; + + auto dstidx = sfc_cde_block.GetIndex(access_id); + const index_t c_token_pos = + block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + IndexType token_offset = fused_token & 0xffffff; + if constexpr(IsInputGemm) + { + 
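+                    // The low 24 bits of the fused id hold the token index and the high 8 bits
+                    // the top-k slot; for the input (gate/up) GEMM the output has one row per
+                    // (token, top-k) pair, so the row index is expanded accordingly here.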
token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scatter_offsets(m0) = static_cast(token_offset) * problem.N; + }); + + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf_fp32, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf), + scatter_offsets); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp new file mode 100755 index 000000000000..9ccd3342626f --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp @@ -0,0 +1,2849 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_gather.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp" + +#define DEBUG_LOG 0 + +namespace ck { + +// Currently we do not have a elegant way to put single lds buffer & double lds buffer pipe in same +// kernel function Blockers: +// 1. Two separted declaration of __shared__ pointer is the key to make sure data access operate on +// two lds chunks. +// 2. 
Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds +// buffer when we declare __shared__ inside blkgemmpipe + +enum Activation +{ + gelu_and_mul = 0, + silu_and_mul = 1 +}; + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_moe_mxgemm(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +#if 0 +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + // auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid, + karg.p_a_scale_grid, + karg.p_b_grid, + karg.p_b_scale_grid, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} +#endif + +template +struct GridwiseMoeGemmMXBNS +{ + using LDSTypeA = ADataType; + using LDSTypeB = BDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + static constexpr auto I8 = Number<8>{}; + static constexpr auto I9 = Number<9>{}; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto MXdlPack = 2; + static constexpr auto NXdlPack = 2; + static constexpr auto KXdlPack = 2; + + static constexpr index_t APackedSize = packed_size_v; + static constexpr index_t BPackedSize = packed_size_v; + + static constexpr bool is_single_rate_mfma = false; + static constexpr auto is_scale_mfma = true; + using mfma_selector = MfmaSelector; + static constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), 
mfma_selector::selected_mfma.k_per_blk / APackedSize); + + // static constexpr index_t NumTokens = 1; + static constexpr index_t SortedTileSize = MPerBlock; + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + const index_t nblock = math::integer_divide_ceil(N, NPerBlock); + const index_t mblock = math::integer_divide_ceil(M, MPerBlock); + const index_t gridx = NSwizzle ? nblock * mblock : nblock; + const index_t gridy = NSwizzle ? 1 : mblock; + + return std::make_tuple(gridx, gridy, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, 
Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + static_assert(!(is_same_v, f4x2_pk_t> && + GemmSpec != GemmSpecialization::Default), + "f4x2_pk_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const 
auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor( + ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor( + BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static auto MakeCGridDescriptor_M_N( + IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template + __host__ __device__ static auto + MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = 
remove_cvref_t>; + return MakeDGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + + struct Problem + { + __host__ Problem(index_t NumTokens_, + index_t TopK_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideScaleA_, + index_t StrideB_, + index_t StrideScaleB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : NumTokens{NumTokens_}, + TopK{TopK_}, + M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideScaleA{StrideScaleA_}, + StrideB{StrideB_}, + StrideScaleB{StrideScaleB_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "NumTokens:" << NumTokens << ", " + << "TopK:" << TopK << ", " + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " + << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t NumTokens; + index_t TopK; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideScaleA; + index_t StrideB; + index_t StrideScaleB; + std::array StrideDs; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const index_t* p_sorted_token_ids_, + const index_t* p_sorted_expert_ids_, + const index_t* p_max_token_id_, + const ADataType* p_a_grid_, + const AScaleDataType* p_a_scale_grid_, + const BDataType* p_b_grid_, + const BScaleDataType* p_b_scale_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t NumTokens_, + index_t TopK_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideScaleA_, + index_t StrideB_, + index_t StrideScaleB_, + std::array StrideDs_, + index_t StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{NumTokens_, + TopK_, + M_, + N_, + K_ / APackedSize, + StrideA_ / APackedSize, + StrideScaleA_, + StrideB_ / BPackedSize, + StrideScaleB_, + StrideDs_, + StrideC_, + k_batch_}, + p_sorted_token_ids{p_sorted_token_ids_}, + p_sorted_expert_ids{p_sorted_expert_ids_}, + p_max_token_id{p_max_token_id_}, + p_a_grid{p_a_grid_}, + p_a_scale_grid{p_a_scale_grid_}, + p_b_grid{p_b_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + p_ds_grid{}, + p_c_grid{p_c_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, 
+ c_element_op{c_element_op_} + { + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + const index_t* p_sorted_token_ids; + const index_t* p_sorted_expert_ids; + const index_t* p_max_token_id; + const ADataType* p_a_grid; + const AScaleDataType* p_a_scale_grid; + const BDataType* p_b_grid; + const BScaleDataType* p_b_scale_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) + { + if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + // KPack * NLane * KLane * K0 * N0 + b_k_split_offset = k_id * karg.KRead; + } + + // Calculate A scale offset + if constexpr(is_same_v) + { + a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize); + } + else if constexpr(is_same_v) + { + a_scale_k_split_offset = + k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA; + } + + // Calculate B scale offset + if constexpr(is_same_v) + { + b_scale_k_split_offset = + k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB; + } + else if constexpr(is_same_v) + { + b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize); + } + + if(k_id < karg.KBatch - 1) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t a_scale_k_split_offset; + index_t b_scale_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return a_lds_block_desc_permuted; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. 
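+            // A rough, illustrative walk-through of the folding below (example
+            // numbers only; the real values follow from the instantiation's
+            // template parameters): kfold and mpair appear to fold extra dimensions
+            // so that one contiguous LDS access stays within a 128-byte window,
+            // letting the XOR swizzle spread requests across banks.
+            //   Assuming bf16 A (2 B per element), AK1 = 8 and M0 = 32:
+            //     AK1 * M0 * sizeof(ADataType) = 8 * 32 * 2 = 512 B > 128 B -> kfold = 1
+            //   and, assuming MPerXdl = 16:
+            //     AK1 * MPerXdl * sizeof(ADataType) = 8 * 16 * 2 = 256 B > 128 B -> mpair = 1,
+            //   i.e. for that configuration no additional folding is applied.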
+ constexpr auto WaveSize = 64; + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = WaveSize / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr auto b_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + 
make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return b_lds_block_desc_permuted; + } + else // RowMajor B + { + constexpr auto WaveSize = 64; + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = WaveSize / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t 
GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + if constexpr(IsInputGemm) + { + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)) * + 2, + c_block_size * sizeof(CShuffleDataType)); + } + else + { + return math::max((a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + c_block_size * sizeof(CShuffleDataType)); + } + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0, + "KPerBlock should be multiple of ScaleBlockSize"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! 
K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! 
" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + + return false; + } + } + } + + // check gridwise gemm pipeline +#if 0 + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } +#endif + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, + // NPerBlock>; + + using mx_scale_t = e8m0_bexp_t; + static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + + template + __device__ static void Run(const index_t* p_sorted_token_ids, + const index_t* p_sorted_expert_ids, + const index_t* p_max_token_id, + const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + IsInputGemm ? 
problem.NumTokens * problem.TopK : problem.NumTokens, + problem.MPadded, + problem.N, + problem.NPadded, + problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.M / (MXdlPack * MPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b)); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; + if(expert_block_id * MPerBlock >= max_token_id) + return; + const index_t expert_id = + __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); + + const auto block_mn = [&]() -> std::pair { + if constexpr(NSwizzle) + { + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = + ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( + bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); + const index_t mid = + __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); + return {nid, mid}; + } + else + { + return {blockIdx.x, blockIdx.y}; + } + }(); + + const index_t block_n_id = block_mn.first; + const index_t block_m_id = block_mn.second; + const index_t token0 = + __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); + + // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); + constexpr auto AKThreads = AK0Threads * AK1Threads; + constexpr auto AMRepeats = MPerBlock / AMThreads; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + + if(token_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray gather_offsets; + static_for<0, AMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + gather_offsets(m0) = static_cast(token_offset) * problem.K; + }); + + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); + const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane( + problem.N * (IsInputGemm ? 
2 : 1) * + math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize)); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // Gride buffer creation + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + // A, B scale buffer + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType), + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + IndexType, + 1, + BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}, + gather_offsets); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // 
Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + StaticBufferTupleOfVector + c_thread_buf_fp32; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // a and b scale processing + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; + + auto a_thread_offset_m = waveId_m; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); + + // B scale load + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + if constexpr(IsInputGemm) + { + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + auto b_block_buf_up = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType) + + b_block_space_size_aligned * sizeof(BDataType)), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride, + b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto b_blockwise_copy_up = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + const BScaleDataType* p_b_scale_grid_up = + p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType); + const auto b_scale_grid_buf_up = make_dynamic_buffer( + p_b_scale_grid_up + 
expert_id * expert_scale_stride / sizeof(BScaleDataType), + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + blockwise_gemm_pipeline.template Run( + // A + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + // Gate and Up + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_buf, + b_block_buf_up, + b_block_slice_copy_step, + // C + c_thread_buf, + c_thread_buf_up, + // A scale + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + // Gate and Up scale + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_thread_copy_up, + b_scale_grid_buf, + b_scale_grid_buf_up, + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, // A + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, // B + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, // C + a_scale_grid_desc_am_ak, // A scale + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, // B scale + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); + } + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && + CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + // TODO: hacky, fix it! 
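+            // Sketch of how the 10-d descriptors below decompose the C block tile
+            // (the factorization is fixed by the pipeline; the concrete sizes here
+            // are example values only):
+            //   M -> M0 x M1 x M2 x (M3 x M4 x M5), with M1 = MWave, M2 = MXdlPack,
+            //        M3 * M4 * M5 = MPerXdl and M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock;
+            //   N -> N0 x N1 x N2 x N3, with N1 = NWave, N2 = NXdlPack, N3 = NPerXdl.
+            //   E.g. assuming MPerBlock = 128, MPerXdl = 16, MXdlPerWave = 4, MXdlPack = 2:
+            //     MWave = 128 / (4 * 16) = 2, M0 = MXdlPerWave / MXdlPack = 4 / 2 = 2,
+            //     and 2 * 2 * 2 * 16 = 128.
+            // The activation / routed-weight loops below walk this decomposition when
+            // computing m_pos for the per-token top-k weight gather.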
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); + + // mul scales + static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock); + static_assert(M5 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; // Mwave id + const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl; + + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack + static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack + static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + + m0 * M2 * M1 * M3 * M4 * M5 + + m1 * M2 * M3 * M4 * M5 + + imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5; + + if constexpr(MulRoutedWeight) + { + topk_weights = + *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == + Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + + /*float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + //up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = up;*/ + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = + topk_weights.AsType()[m5] * + c_thread_buf_fp32[cidx]; + } + } + }); + }); + }); + }); + }); + }); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = 
make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) + // per shuffle + M1, // M1 = MWave + M2, // M2 = MXdlPack + M3, // M3 * M4 * M5 = MPerXdl + M4, + M5)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) + // per shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // 
return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = + container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_m_id, 0, block_n_id, 0); + // return make_multi_index(block_work_idx[I0], 0, + // block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferCluster = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr index_t scatter_weight_idx = 3; // hack fix felix + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = + CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = + CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + StaticallyIndexedArray scatter_offsets; + + auto dstidx = sfc_cde_block.GetIndex(access_id); 
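+            // Scatter-offset sketch for this access (the bit layout matches the fused
+            // token ids consumed by the gather above; numbers are illustration only):
+            //   each entry of p_sorted_token_ids packs (topk_slot << 24) | token_id,
+            //   so fused_token & 0xffffff recovers the token and fused_token >> 24
+            //   the expert slot. For the first (gate/up) GEMM the destination row is
+            //   token_id * TopK + slot, e.g. token_id = 7, TopK = 4, slot = 2 scatters
+            //   to row 7 * 4 + 2 = 30, i.e. element offset 30 * problem.N in the
+            //   output tensor.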
+ const index_t c_token_pos = + block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + IndexType token_offset = fused_token & 0xffffff; + if constexpr(IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scatter_offsets(m0) = static_cast(token_offset) * problem.N; + }); + + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf_fp32, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf), + scatter_offsets); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + +#if 0 + template + __device__ static void Run_2Lds(const index_t* p_sorted_token_ids, + const index_t* p_sorted_expert_ids, + const index_t* p_max_token_id, + const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + void* p_shared1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, + problem.MPadded, + problem.N, + problem.NPadded, + problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( + make_tuple((IsInputGemm ? 
problem.NumTokens : problem.M) / (MXdlPack * MPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b)); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged"); + const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; + if(expert_block_id * MPerBlock >= max_token_id) + return; + const index_t expert_id = + __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); + const auto block_mn = [&]() -> std::pair { + if constexpr(NSwizzle) + { + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = + ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( + bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); + const index_t mid = + __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); + return {nid, mid}; + } + else + { + return {blockIdx.x, blockIdx.y}; + } + }(); + + const index_t block_n_id = block_mn.first; + const index_t block_m_id = block_mn.second; + const index_t token0 = + __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); + + // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); + constexpr auto AKThreads = AK0Threads * AK1Threads; + constexpr auto AMRepeats = MPerBlock / AMThreads; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + + if(token_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray gather_offsets; + static_for<0, AMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + gather_offsets(m0) = static_cast(token_offset) * problem.K; + }); + + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 
2 : 1)); + const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane( + problem.N * math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize)); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType), + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + // dummy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + IndexType, + 1, + BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}, + gather_offsets); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf_ping = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + I1, + Number{}, + Number{}, + Number{}>, + Sequence<1, 2, 0, 3, 4>, + 4, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + 0, + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = 
blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + StaticBufferTupleOfVector + c_thread_buf_fp32; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // a and b scale processing + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; + + auto a_thread_offset_m = waveId_m; + + // get each thread's offset int the scale tensor + const index_t token_scale_pos = block_m_id * MPerBlock; + if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens) + return; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); + + // B scale load + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2; + const auto b_scale_grid_buf_up = make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * 
NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + c_thread_buf_up, + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_thread_copy_up, + b_scale_grid_buf, + b_scale_grid_buf_up, + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); + } + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + // mul scales + + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); + static_assert(M4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0 / MXdlPack, + n0 / NXdlPack, + m0 % MXdlPack, + n0 % NXdlPack, + m2 * M4 + m4)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if 
constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = + topk_weights.AsType()[m4] * c_thread_buf_fp32[cidx]; + } + } + }); + }); + }); + }); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple(make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, 
problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = + container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_m_id, 0, block_n_id, 0); + // return make_multi_index(block_work_idx[I0], 0, + // block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferCluster = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr index_t scatter_weight_idx = 3; // hack fix felix + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + 
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = + CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = + CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + StaticallyIndexedArray scatter_offsets; + + auto dstidx = sfc_cde_block.GetIndex(access_id); + const index_t c_token_pos = + block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + IndexType token_offset = fused_token & 0xffffff; + if constexpr(IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scatter_offsets(m0) = static_cast(token_offset) * problem.N; + }); + + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf_fp32, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf), + scatter_offsets); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +#endif +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp new file mode 100755 index 000000000000..be85528f2826 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp @@ -0,0 +1,2761 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
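+// Overview (as implemented below): GridwiseMoeGemmMX_BPreshuffle is a grid-wise MoE GEMM that
+// consumes microscaling (MX, e8m0 block-exponent) scale tensors for A and B together with a
+// pre-shuffled B operand. Token rows are routed through p_sorted_token_ids / p_sorted_expert_ids,
+// so rows of A are gathered on load and rows of C are scattered on store; split-K work is
+// distributed via SplitKBatchOffset. The Activation enum selects the fused gate/up epilogue,
+// where silu_and_mul denotes SiLU(gate) * up with SiLU(x) = x / (1 + exp(-x)).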
+
+#pragma once
+
+#include "ck/utility/common_header.hpp"
+#include "ck/utility/env.hpp"
+#include "ck/tensor_description/multi_index_transform_helper.hpp"
+#include "ck/tensor_description/tensor_descriptor.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
+#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_selector.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
+
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_gather_direct_load.hpp"
+
+#define DEBUG_LOG 0
+
+namespace ck {
+
+// Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
+// pipelines in the same kernel function. Blockers:
+// 1. Two separate declarations of the __shared__ pointer are the key to making sure data accesses
+//    operate on two LDS chunks.
+// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may
+//    not reuse the same LDS buffer when __shared__ is declared inside the block GEMM pipeline.
+
+enum Activation
+{
+    gelu_and_mul = 0,
+    silu_and_mul = 1
+};
+
+template
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+#endif
+    // __attribute__((amdgpu_waves_per_eu(1, 1)))
+    kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
+    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
+
+    GridwiseGemm::template Run(
+        karg.p_sorted_token_ids,
+        karg.p_sorted_expert_ids,
+        karg.p_max_token_id,
+        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
+        // the scale grids advance by the scale-specific split-K offsets (as in the 2-LDS kernel below)
+        karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
+        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
+        karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
+        karg.p_ds_grid,
+        karg.p_c_grid,
+        p_shared,
+        karg,
+        karg.a_element_op,
+        karg.b_element_op,
+        karg.c_element_op);
+#else
+    ignore = karg;
+#endif // end of if (defined(__gfx9__))
+}
+
+template
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+#endif
+    // __attribute__((amdgpu_waves_per_eu(1, 1)))
+    kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
+    __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+    __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
+
+    GridwiseGemm::template Run_2Lds(
+        karg.p_sorted_token_ids,
+        karg.p_sorted_expert_ids,
+        karg.p_max_token_id,
+        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
+        karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
+        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
+        karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
+        karg.p_ds_grid,
+        karg.p_c_grid,
+        p_shared_0,
+        p_shared_1,
+        karg,
+        karg.a_element_op,
+        karg.b_element_op,
+        karg.c_element_op);
+#else
+    ignore = karg;
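+    // On non-gfx9 targets the kernel body above is compiled out; the assignment above merely
+    // consumes karg to avoid an unused-parameter warning.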
+#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseMoeGemmMX_BPreshuffle +{ + using LDSTypeA = ADataType; + using LDSTypeB = BDataType; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + static constexpr auto I8 = Number<8>{}; + static constexpr auto I9 = Number<9>{}; + + static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock = + CDEShuffleBlockTransferScalarPerVectors{}[I0]; + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = false; + static constexpr auto is_scale_mfma = true; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto MXdlPack = 2; + static constexpr auto NXdlPack = 2; + static constexpr auto KXdlPack = 2; + + //> KPack is at least the k_per_blk of selected mfma + // + // Should be a multiple of k_per_blk. + // TODO: Move this to blockwise pipeline base + // KPack in packed data types for pk A/B + + static constexpr index_t APackedSize = packed_size_v; + static constexpr index_t BPackedSize = packed_size_v; + + using mfma_selector = MfmaSelector; + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, mfma_selector::selected_mfma.k_per_blk / APackedSize); + + static constexpr index_t NLane = NPerXdl; + static constexpr index_t KLane = 64 / NLane; + static constexpr index_t NWave = NPerBlock / NPerXdl / NXdlPerWave; + static constexpr index_t KRepeat = KPerBlock / KLane / KPack; + + // static constexpr index_t NumTokens = 1; + static constexpr index_t SortedTileSize = MPerBlock; + + using mx_scale_t = e8m0_bexp_t; + static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t); + static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t); + static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0, + "A scale pack data type too large!"); + static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, + "B scale pack data type too large!"); + + static constexpr auto MakeDsGridPointer() + { + return generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + + return static_cast(nullptr); + }, + Number{}); + } + + using DsGridPointer = decltype(MakeDsGridPointer()); + + using ThisThreadBlock = ThisThreadBlock; + + __host__ static auto CalculateGridSize(index_t M, index_t N) + { + const index_t nblock = math::integer_divide_ceil(N, NPerBlock); + const index_t mblock = math::integer_divide_ceil(M, MPerBlock); + const index_t gridx = NSwizzle ? nblock * mblock : nblock; + const index_t gridy = NSwizzle ? 
1 : mblock; + + return std::make_tuple(gridx, gridy, 1); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateBN0Shuffled(index_t N) + { + return math::integer_divide_ceil(N, NLane); + } + __host__ static auto CalculateBK0Shuffled(index_t K) + { + return math::integer_divide_ceil(K, KLane * KPack); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + if constexpr(IsXor) + { + constexpr auto permuted_desc = transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_xor_with_modulo_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return transform_tensor_descriptor( + permuted_desc, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{})); + } + else + { + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<4>{}, Sequence<0, 1, 2, 3>{})); + } + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { 
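+            // MPad and KPad are the block-aligned sizes (cf. CalculateMPadded / CalculateKPadded
+            // above); the right-pad transforms below extend (M, K) to (MPad, KPad) before K is
+            // unmerged into (AK0, AK1Value).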
+ // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + const auto a_grid_desc_permuted = transform_tensor_descriptor( + a_grid_desc_ak0_m_ak1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(M, AK0Number)), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto a_grid_desc = transform_tensor_descriptor( + a_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)), + make_pass_through_transform(M), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_grid_desc; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0) + { + constexpr index_t NkSwizzleNumber = Number{}; + return make_naive_tensor_descriptor_packed( + make_tuple(N0 / NWave / NXdlPack, NWave, NXdlPack, K0, NkSwizzleNumber)); + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return 
make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + static_assert(!(is_same_v, f4x2_pk_t> && + GemmSpec != GemmSpecialization::Default), + "f4x2_pk_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + const auto b_grid_desc_permuted = transform_tensor_descriptor( + b_grid_desc_bk0_n_bk1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(N, BK0Number)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto b_grid_desc = transform_tensor_descriptor( + b_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, BK0Number)), + make_pass_through_transform(N), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc; + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&) + { + 
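+        // MWaves = MPerBlock / (MXdlPerWave * MPerXdl) is the number of waves tiling the M
+        // dimension of one block tile; MakeGemmMmaTileDescriptor (defined above) then remaps the
+        // AK0 x M x AK1 block descriptor into the per-XDL MMA tile layout.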
constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor( + ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor( + BBlockDesc_BK0_N_BK1{}); + } + + template + __host__ __device__ static auto MakeCGridDescriptor_M_N( + IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template + __host__ __device__ static auto + MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + __host__ __device__ static auto MakeDsGridDescriptor_M_N( + index_t M, index_t MPad, index_t N, index_t NPad, std::array StrideDs) + { + return generate_tuple( + [&](auto i) { + using DLayout = remove_cvref_t>; + return MakeDGridDescriptor_M_N(M, MPad, N, NPad, StrideDs[i]); + }, + Number{}); + } + + template + __device__ static constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + return generate_tuple( + [&](auto i) { + return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n[i], MBlock, NBlock); + }, + Number{}); + } + + struct Problem + { + __host__ Problem(index_t NumTokens_, + index_t TopK_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideScaleA_, + index_t StrideB_, + index_t StrideScaleB_, + std::array StrideDs_, + index_t StrideC_, + index_t KBatch_) + : NumTokens{NumTokens_}, + TopK{TopK_}, + M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideScaleA{StrideScaleA_}, + StrideB{StrideB_}, + StrideScaleB{StrideScaleB_}, + StrideDs{StrideDs_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)}, + BN0Shuffled{CalculateBN0Shuffled(N_)}, + BK0Shuffled{CalculateBK0Shuffled(K_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "NumTokens:" << NumTokens << ", " + << "TopK:" << TopK << ", " + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << 
", " + << "SA:" << StrideA << ", " + << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " + << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t NumTokens; + index_t TopK; + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideScaleA; + index_t StrideB; + index_t StrideScaleB; + std::array StrideDs; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + // FOR PRESHUFFLE ONLY + index_t BN0Shuffled; + index_t BK0Shuffled; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const index_t* p_sorted_token_ids_, + const index_t* p_sorted_expert_ids_, + const index_t* p_max_token_id_, + const ADataType* p_a_grid_, + const AScaleDataType* p_a_scale_grid_, + const BDataType* p_b_grid_, + const BScaleDataType* p_b_scale_grid_, + std::array p_ds_grid_, + CDataType* p_c_grid_, + index_t NumTokens_, + index_t TopK_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideScaleA_, + index_t StrideB_, + index_t StrideScaleB_, + std::array StrideDs_, + index_t StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_) + : Problem{NumTokens_, + TopK_, + M_, + N_, + K_ / APackedSize, + StrideA_ / APackedSize, + StrideScaleA_, + StrideB_ / BPackedSize, + StrideScaleB_, + StrideDs_, + StrideC_, + k_batch_}, + p_sorted_token_ids{p_sorted_token_ids_}, + p_sorted_expert_ids{p_sorted_expert_ids_}, + p_max_token_id{p_max_token_id_}, + p_a_grid{p_a_grid_}, + p_a_scale_grid{p_a_scale_grid_}, + p_b_grid{p_b_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + p_ds_grid{}, + p_c_grid{p_c_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_} + { + + // populate pointer, desc for Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType_ = remove_cvref_t>; + + // D pointer + p_ds_grid(i) = static_cast(p_ds_grid_[i]); + }); + } + + const index_t* p_sorted_token_ids; + const index_t* p_sorted_expert_ids; + const index_t* p_max_token_id; + const ADataType* p_a_grid; + const AScaleDataType* p_a_scale_grid; + const BDataType* p_b_grid; + const BScaleDataType* p_b_scale_grid; + DsGridPointer p_ds_grid; + CDataType* p_c_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + }; + + struct SplitKBatchOffset + { + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) + { + if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead; + } + else if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + // KPack * NLane * KLane * K0 * N0 + b_k_split_offset = k_id * karg.KRead * NPerXdl; + } + + // Calculate A scale offset + a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize) * MXdlPack * + MPerXdl / scale_pack_size_a; + + // Calculate B scale offset + b_scale_k_split_offset = k_id * 
karg.KRead / (ScaleBlockSize / BPackedSize) * NXdlPack * + NPerXdl / scale_pack_size_b; + + if(k_id < karg.KBatch - 1) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t a_scale_k_split_offset; + index_t b_scale_k_split_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // contiguous in LDS + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return a_lds_block_desc_permuted; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto WaveSize = 64; + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = WaveSize / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + return make_naive_tensor_descriptor_packed(make_tuple(Number{}, + I1, + Number{}, + Number{}, + Number{})); + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max(a_block_space_size_aligned * sizeof(ADataType), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are 
controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0, + "KPerBlock should be multiple of ScaleBlockSize"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + + return false; + } + } + } + + // check gridwise gemm pipeline +#if 0 + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } +#endif + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, + // NPerBlock>; + +#if 0 + template + __device__ static void Run(const index_t* p_sorted_token_ids, + const index_t* p_sorted_expert_ids, + const index_t* p_max_token_id, + const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + IsInputGemm ? 
problem.NumTokens : problem.NumTokens * problem.TopK, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, + problem.MPadded, + problem.N, + problem.NPadded, + problem.StrideC); + + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( + make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerBlock), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b)); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged"); + const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; + if(expert_block_id * MPerBlock >= max_token_id) + return; + const index_t expert_id = + __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); + + const auto block_mn = [&]() -> std::pair { + if constexpr(NSwizzle) + { + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = + ecnt > 0 ? 
ecnt : 1; // p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( + bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); + const index_t mid = + __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); + return {nid, mid}; + } + else + { + return {blockIdx.x, blockIdx.y}; + } + }(); + + const index_t block_n_id = block_mn.first; + const index_t block_m_id = block_mn.second; + const index_t token0 = + __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); + + // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); + constexpr auto AKThreads = AK0Threads * AK1Threads; + constexpr auto AMRepeats = MPerBlock / AMThreads; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + + if(token_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray gather_offsets; + static_for<0, AMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + gather_offsets(m0) = static_cast(token_offset) * problem.K / APackedSize; + }); + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); + const index_t expert_scale_stride = + __builtin_amdgcn_readfirstlane(problem.N * (IsInputGemm ? 
2 : 1) * + math::integer_divide_ceil(problem.K, ScaleBlockSize)); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + expert_id * expert_stride / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + + // A, B scale buffer + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + // dummy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + // A matrix blockwise copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather< + ThisThreadBlock, + AElementwiseOperation, + ck::tensor_operation::element_wise::PassThrough, + InMemoryDataOperationEnum::Set, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + IndexType, + 1, + BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}, + gather_offsets); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + I1, + Number{}, + Number{}, + Number{}>, + Sequence<1, 2, 0, 3>, + 4, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), + a_block_desc_ak0_m_ak1.GetElementSpaceSize() / APackedSize); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + StaticBufferTupleOfVector + c_thread_buf_fp32; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // a and b scale processing + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + static 
constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; + + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; + + auto a_thread_offset_m = waveId_m; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); + + // B scale load + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride / BPackedSize, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack / KGroup * (get_thread_local_1d_id() % WarpSize))); + const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2; + const auto b_scale_grid_buf_up = make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride, + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + c_thread_buf_up, + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + 
b_scale_thread_copy_up, + b_scale_grid_buf, + b_scale_grid_buf_up, + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); + } + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + // mul scales + static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock); + static_assert(M4 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl; + + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave, 1>{}([&](auto n0) { + static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 + + m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4; + if constexpr(MulRoutedWeight) + { + topk_weights = *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, m2 * M4 + m4)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if constexpr(ActivationOperation == Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m4]; + up = up * topk_weights.AsType()[m4]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + 
c_thread_buf_fp32(cidx) = + topk_weights.AsType()[m4] * c_thread_buf_fp32[cidx]; + } + } + }); + }); + }); + }); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return 
ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie([&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = + container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_m_id, 0, block_n_id, 0); + // return make_multi_index(block_work_idx[I0], 0, + // block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferCluster = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr index_t scatter_weight_idx = 1; // hack fix felix + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = + CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = + CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + StaticallyIndexedArray scatter_offsets; + + auto dstidx = sfc_cde_block.GetIndex(access_id); + const index_t c_token_pos = + block_m_id * MPerBlock + 
threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + IndexType token_offset = fused_token & 0xffffff; + if constexpr(IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scatter_offsets(m0) = static_cast(token_offset) * problem.N; + }); + + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf_fp32, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf), + scatter_offsets); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +#endif + + template + __device__ static void Run_2Lds(const index_t* p_sorted_token_ids, + const index_t* p_sorted_expert_ids, + const index_t* p_max_token_id, + const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = a_element_op; + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, + problem.MPadded, + problem.K, + problem.KPadded, + problem.StrideA, + problem.AK0); + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + IsInputGemm ? 
problem.NumTokens * problem.TopK : problem.NumTokens, + problem.MPadded, + problem.N, + problem.NPadded, + problem.StrideC); + + // We pad the M unconditionaly for Scale + const auto Padded_Scale_M = + math::integer_divide_ceil(problem.M, ScaleBlockSize) * ScaleBlockSize; + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(Padded_Scale_M / (MXdlPack * MPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / APackedSize)) * + MPerXdl * MXdlPack / scale_pack_size_a, + 64 * KXdlPack * MXdlPack / scale_pack_size_a, + 1)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b), + make_tuple(math::integer_divide_ceil(problem.K * problem.KBatch, + (ScaleBlockSize / BPackedSize)) * + NPerXdl * NXdlPack / scale_pack_size_b, + 64 * KXdlPack * NXdlPack / scale_pack_size_b, + 1)); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; + if(expert_block_id * MPerBlock >= max_token_id) + return; + const index_t expert_id = + __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); + const auto block_mn = [&]() -> std::pair { + if constexpr(NSwizzle) + { + const index_t ecnt_prefix = p_max_token_id[1 + expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix; + const index_t expert_swizzle = + ecnt > 0 ? 
ecnt : 1; // p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane( + bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); + const index_t mid = + __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); + return {nid, mid}; + } + else + { + return {blockIdx.x, blockIdx.y}; + } + }(); + + const index_t block_n_id = block_mn.first; + const index_t block_m_id = block_mn.second; + const index_t token0 = + __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); + + // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); + constexpr auto AKThreads = AK0Threads * AK1Threads; + constexpr auto AMRepeats = MPerBlock / AMThreads; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads; + + if(token_pos >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray gather_offsets; + static_for<0, AMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads]; + index_t token_offset = fused_token & 0xffffff; + if constexpr(!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + gather_offsets(m0) = static_cast(token_offset) * problem.K; + }); + + const index_t expert_stride = + __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1)); + const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane( + problem.N * (IsInputGemm ? 
2 : 1) * + math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize)); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack); + + // Gride buffer creation + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + + // A, B scale buffer + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType), + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise direct to LDS copy + auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad< + ThisThreadBlock, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + IndexType, + 1>(a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + gather_offsets); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf_ping = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + auto b_blockwise_copy = + ThreadwiseTensorSliceTransfer_v2{}, + I1, + Number{}, + Number{}, + Number{}>, + Sequence<0, 1, 2, 3, 4>, + 4, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + 0, + KPack * (get_thread_local_1d_id() % WarpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + decltype(c_thread_buf) c_thread_buf_up; + + StaticBufferTupleOfVector + c_thread_buf_fp32; + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // a and b scale processing + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + 
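+        // Note (reviewer sketch, inferred from the scale copies below): the A/B scales are
+        // stored packed, KXdlPack * MXdlPack values per A xdlops tile and KXdlPack * NXdlPack
+        // per B tile, so each lane starts at lane_id * KXdlPack * MXdlPack / scale_pack_size_{a,b}
+        // along the packed axis, while the wave id selects the scale tensor's tile row below.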
const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; + + auto a_thread_offset_m = waveId_m; + + // get each thread's offset int the scale tensor + const index_t token_scale_pos = block_m_id * MPerBlock; + if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens) + return; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); + + // B scale load + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + if constexpr(IsInputGemm) + { + const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2; + const auto b_grid_buf_up = make_dynamic_buffer( + p_b_grid_up + expert_id * expert_stride, + b_grid_desc_bpreshuffled.GetElementSpaceSize()); + auto b_blockwise_copy_up = + ThreadwiseTensorSliceTransfer_v2{}, + I1, + Number{}, + Number{}, + Number{}>, + Sequence<0, 1, 2, 3, 4>, + 4, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + 0, + KPack * (get_thread_local_1d_id() % WarpSize))); + const BScaleDataType* p_b_scale_grid_up = + p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType); + const auto b_scale_grid_buf_up = make_dynamic_buffer( + p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType), + b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + blockwise_gemm_pipeline.template Run( + // A + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + // Gate and Up + b_grid_desc_bpreshuffled, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_blockwise_copy_up, + b_grid_buf, + b_grid_buf_up, + b_block_bufs, + 
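+                // note: b_block_bufs above is the ping/pong pair (b_block_buf_ping /
+                // b_block_buf_pong) used to double-buffer the preshuffled B tiles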
b_block_slice_copy_step, + // C + c_thread_buf, + c_thread_buf_up, + // A scale + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + // B scale + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_thread_copy_up, + b_scale_grid_buf, + b_scale_grid_buf_up, + num_k_block_main_loop); + } + else + { + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, // A + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, // B + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, // C + a_scale_grid_desc_am_ak, // A scale + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, // B scale + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); + } + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && + CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); + + // mul scales + + static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock); + static_assert(M5 == 4); + const index_t m1 = get_warp_local_1d_id() / NWave; + const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl; + + vector_type topk_weights; // for gemm2 only + static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) { + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack + static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack + static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk + const index_t m_pos = block_m_id * MPerBlock + + m0 * M2 * M1 * M3 * M4 * M5 + + m1 * M2 * M3 * M4 * M5 + + imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5; + if constexpr(MulRoutedWeight) + { + topk_weights = + *c_style_pointer_cast*>( + p_ds_grid[I2] + m_pos); + } + static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size + constexpr index_t c_offset = + blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( + make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5)); + constexpr auto cidx = Number{}; + + if constexpr(IsInputGemm) // gu fusion + { + if 
constexpr(ActivationOperation == + Activation::silu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Silu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + else if(ActivationOperation == Activation::gelu_and_mul) + { + float gate = c_thread_buf[cidx]; + float up = c_thread_buf_up[cidx]; + if constexpr(MulRoutedWeight) + { + gate = gate * topk_weights.AsType()[m5]; + up = up * topk_weights.AsType()[m5]; + } + tensor_operation::element_wise::Gelu{}(gate, gate); + c_thread_buf_fp32(cidx) = gate * up; + } + } + else + { + c_thread_buf_fp32(cidx) = c_thread_buf[cidx]; + if constexpr(MulRoutedWeight) + { + c_thread_buf_fp32(cidx) = + topk_weights.AsType()[m5] * + c_thread_buf_fp32[cidx]; + } + } + }); + }); + }); + }); + }); + }); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4, + M5)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) + // per shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, 
+ make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + return make_dynamic_buffer( + p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = + container_concat(make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_m_id, 0, block_n_id, 0); + // return make_multi_index(block_work_idx[I0], 0, + // block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferCluster = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr index_t scatter_weight_idx = 3; // hack fix felix + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make + // Sequence support + // arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + IndexType, + 1, // ScatterDim + true, // OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + >{c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; + + auto c_grid_buf = 
make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = + CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = + CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + StaticallyIndexedArray scatter_offsets; + + auto dstidx = sfc_cde_block.GetIndex(access_id); + const index_t c_token_pos = + block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + IndexType token_offset = fused_token & 0xffffff; + if constexpr(IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + scatter_offsets(m0) = static_cast(token_offset) * problem.N; + }); + + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf_fp32, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf), + scatter_offsets); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp new file mode 100755 index 000000000000..61d0f9e0d557 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_permute.hpp @@ -0,0 +1,339 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
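+//
+// Overview (reviewer sketch, summarizing the code below): GridwisePermute implements an N-D
+// permutation that swaps the last two dimensions. Each workgroup owns one
+// [NPerBlock, HPerBlock, WPerBlock] tile, loads it from global memory into an LDS relay
+// buffer, and writes it back through a transposed view of the output descriptor, so the
+// transpose costs one LDS round trip per tile. Rough host-side wiring (hedged sketch,
+// template arguments elided):
+//
+//   using Permute = GridwisePermute</* data types, descriptors, block/tile sizes, ... */>;
+//   assert(Permute::CheckValidity(in_desc, out_desc));
+//   const auto block_2_tile_map = Permute::MakeDefaultBlock2TileMap(in_desc);
+//   const index_t grid_size     = block_2_tile_map.CalculateGridSize(in_desc);
+//   kernel_nd_permute<Permute, /* ... */><<<grid_size, BlockSize>>>(
+//       in_desc, out_desc, p_in, p_out, elementwise_op, block_2_tile_map);
+//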
+ +#pragma once + +#include +#include +#include + +#include "ck/tensor_description/cluster_descriptor.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_nd_permute(const InGridDesc in_grid_desc, + const OutGridDesc out_grid_desc, + const InDataType* p_in_global, + OutDataType* p_out_global, + const ElementwiseOperation elementwise_op, + const Block2TileMap block_2_tile_map) +{ + __shared__ char p_shared[GridwisePermute::GetSharedMemoryNumberOfByte()]; + + GridwisePermute::Run(in_grid_desc, + out_grid_desc, + p_in_global, + p_out_global, + p_shared, + elementwise_op, + block_2_tile_map); +} + +template +struct GridwisePermute +{ + static_assert(InGridDesc::GetNumOfDimension() == OutGridDesc::GetNumOfDimension()); + static_assert(3 <= InGridDesc::GetNumOfDimension()); + static_assert((InGridDesc::GetNumOfDimension() - 2) <= SrcVectorDim && + SrcVectorDim < InGridDesc::GetNumOfDimension()); + static_assert((OutGridDesc::GetNumOfDimension() - 2) <= DstVectorDim && + DstVectorDim < OutGridDesc::GetNumOfDimension()); + static_assert(SrcVectorDim != DstVectorDim); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + using ThisThreadBlock = ThisThreadBlock; + + struct Block2TileMap + { + static constexpr index_t NumDim = InGridDesc::GetNumOfDimension(); + static_assert(3 <= NumDim); + + static constexpr auto I0 = Number<0>{}; + + Block2TileMap() = delete; + Block2TileMap(const Block2TileMap&) = default; + Block2TileMap(Block2TileMap&&) = delete; + + ~Block2TileMap() = default; + + Block2TileMap& operator=(const Block2TileMap&) = delete; + Block2TileMap& operator=(Block2TileMap&&) = delete; + + explicit Block2TileMap(const InGridDesc& desc) : desc_(desc) {} + + __host__ constexpr index_t CalculateGridSize(const InGridDesc& desc) const + { + const auto N0 = + math::integer_divide_ceil(desc.GetLength(Number{}), NPerBlock); + const auto H0 = + math::integer_divide_ceil(desc.GetLength(Number{}), HPerBlock); + const auto W0 = + math::integer_divide_ceil(desc.GetLength(Number{}), WPerBlock); + + const index_t grid_size = N0 * H0 * W0; + + return grid_size; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + static_assert(TopIdx::Size() == 1); + + auto block_1d_id = idx_top[I0]; + + const auto N0 = + math::integer_divide_ceil(desc_.GetLength(Number{}), NPerBlock); + const auto H0 = + math::integer_divide_ceil(desc_.GetLength(Number{}), HPerBlock); + const auto W0 = + math::integer_divide_ceil(desc_.GetLength(Number{}), WPerBlock); + + block_1d_id = block_1d_id % (N0 * H0 * W0); + + index_t idx_N0 = block_1d_id / (H0 * W0); + index_t idx_H0 = (block_1d_id % (H0 * W0)) / W0; + index_t idx_W0 = block_1d_id % W0; + + return make_tuple(idx_N0, idx_H0, idx_W0); + } + + private: + const InGridDesc desc_; + }; + + using DefaultBlock2TileMap = Block2TileMap; + + // use an [NPerBlock, HPerBlock, WPerBlock] tensor as element-copy relay + __host__ __device__ static constexpr auto GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock() + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, Number{}), + make_tuple(Number{}, + Number{}, + I1)); + } + + // for N-dimension descriptor, reserve its 
last 2 dimensions, then merge its leading dimensions + // into single one. finally, form a 3D descriptor: [d(0), d(1), ..., d(N - 2), d(N - 1)] -> + // [(d(0) x d(1) x ...), d(N - 2), d(N - 1)] + template + __host__ __device__ static constexpr auto GetMergedDesc(const GridDesc& desc) + { + constexpr index_t NumDim = GridDesc::GetNumOfDimension(); + static_assert(3 <= NumDim); + + const auto merged_desc = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(generate_tuple( + [&](auto I) { return desc.GetLength(I); }, Number{})), + make_pass_through_transform(desc.GetLength(Number{})), + make_pass_through_transform(desc.GetLength(Number{}))), + make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number{}), + Sequence{}, + Sequence{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + return merged_desc; + } + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto in_block_desc_nperblock_hperblock_wperblock = + GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock(); + + return in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize() * + sizeof(InDataType); + } + + __host__ __device__ static constexpr auto MakeDefaultBlock2TileMap(const InGridDesc& desc) + { + return DefaultBlock2TileMap{desc}; + } + + __host__ __device__ static constexpr bool CheckValidity(const InGridDesc& in_grid_desc, + const OutGridDesc& out_grid_desc) + { + constexpr index_t NumDim = InGridDesc::GetNumOfDimension(); + + // check if we only swap last 2 dimensions + bool valid = true; + static_for<0, NumDim - 2, 1>{}([&](auto I) { + if(valid && in_grid_desc.GetLength(I) != out_grid_desc.GetLength(I)) + { + valid = false; + } + }); + + return valid && + (in_grid_desc.GetLength(Number{}) == + out_grid_desc.GetLength(Number{})) && + (in_grid_desc.GetLength(Number{}) == + out_grid_desc.GetLength(Number{})); + } + + template + __device__ static void Run(const InGridDesc in_grid_desc, + const OutGridDesc out_grid_desc, + const InDataType* p_in_global, + OutDataType* p_out_global, + void* __restrict__ p_shared, + const ElementwiseOperation elementwise_op, + const Block2TileMap& block_2_tile_map) + { + auto in_global_buf = make_dynamic_buffer( + p_in_global, in_grid_desc.GetElementSpaceSize()); + + auto out_global_buf = make_dynamic_buffer( + p_out_global, out_grid_desc.GetElementSpaceSize()); + + // each workgroup handles an [NPerBlock, HPerBlock, WPerBLock] slice-transpose problem + const auto block_work_idx = + block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * NPerBlock); + + const index_t h_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * HPerBlock); + + const index_t w_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * WPerBlock); + + // create [NPerBlock, HPerBlock, WPerBLock] shaped LDS buffer + constexpr auto in_block_desc_nperblock_hperblock_wperblock = + GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock(); + + auto in_block_buf = make_dynamic_buffer( + static_cast(p_shared), + in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize()); + + using BlockSliceLengths = Sequence; + using InBlockTransferAccessOrder = Sequence<0, 1, 2>; + + constexpr index_t SrcVectorDimAfterMerge = + SrcVectorDim - (InGridDesc::GetNumOfDimension() - 3); + constexpr index_t DstVectorDimAfterMerge = SrcVectorDimAfterMerge; + + using ck::tensor_operation::element_wise::PassThrough; + 
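+        // Note (added for orientation): both descriptors are first merged into a 3-D view
+        // [(d0 x d1 x ...), d(N-2), d(N-1)]; the output view built below additionally swaps
+        // its last two dimensions (the Sequence<0, 2, 1> mapping), so the LDS->global store
+        // performs the actual transpose while the global->LDS load reads the input in its
+        // natural order.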
+ // merge input descriptor into [(in_grid_desc.GetLength(0) x in_grid_desc.GetLength(1) x + // ...), in_grid_desc.GetLength(NumDim - 2), in_grid_desc.GetLength(NumDim - 1)] + const auto in_grid_desc_n_h_w = GetMergedDesc(in_grid_desc); + + // a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from global memory to LDS + auto in_global_load = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + ElementwiseOperation, + PassThrough, + InMemoryDataOperationEnum::Set, + BlockSliceLengths, + InBlockTransferThreadClusterLengths, + InBlockTransferThreadClusterArrangeOrder, + InDataType, + InDataType, + decltype(in_grid_desc_n_h_w), + decltype(in_block_desc_nperblock_hperblock_wperblock), + InBlockTransferAccessOrder, + InBlockTransferAccessOrder, + SrcVectorDimAfterMerge, + 2, + SrcScalarPerVector, + 1, + 1, + 1, + true, + true>(in_grid_desc_n_h_w, + make_multi_index( + n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid), + PassThrough{}, + in_block_desc_nperblock_hperblock_wperblock, + make_multi_index(0, 0, 0), + PassThrough{}); + + // merge output descriptor into [(out_grid_desc.GetLength(0) x out_grid_desc.GetLength(1) x + // ...), out_grid_desc.GetLength(NumDim - 2), out_grid_desc.GetLength(NumDim - 1)] + const auto out_grid_desc_n_w_h = GetMergedDesc(out_grid_desc); + + // create transposed view of output tensor + const auto out_grid_desc_n_h_w = transform_tensor_descriptor( + out_grid_desc_n_w_h, + make_tuple(make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I0)), + make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I1)), + make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I2))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{})); + + // a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from LDS to global memory + auto out_global_store = ThreadGroupTensorSliceTransfer_v4r1< + ThisThreadBlock, + ElementwiseOperation, + PassThrough, + InMemoryDataOperationEnum::Set, + BlockSliceLengths, + InBlockTransferThreadClusterLengths, + InBlockTransferThreadClusterArrangeOrder, + InDataType, + OutDataType, + decltype(in_block_desc_nperblock_hperblock_wperblock), + decltype(out_grid_desc_n_h_w), + InBlockTransferAccessOrder, + InBlockTransferAccessOrder, + 2, + DstVectorDimAfterMerge, + 1, + DstScalarPerVector, + 1, + 1, + true, + true>(in_block_desc_nperblock_hperblock_wperblock, + make_multi_index(0, 0, 0), + PassThrough{}, + out_grid_desc_n_h_w, + make_multi_index( + n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid), + elementwise_op); + + in_global_load.Run(in_grid_desc_n_h_w, + in_global_buf, + in_block_desc_nperblock_hperblock_wperblock, + in_block_buf, + I0); + + out_global_store.Run(in_block_desc_nperblock_hperblock_wperblock, + in_block_buf, + out_grid_desc_n_h_w, + out_global_buf, + I0); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp new file mode 100755 index 000000000000..c41eef8c45fa --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp @@ -0,0 +1,155 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
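+//
+// Overview (reviewer sketch, summarizing the code below): GridwisePutElement_1D scatters a
+// 1-D input into the output at positions taken from an index tensor; per element i with
+// indices[i] >= 0, and depending on the MemOp template parameter:
+//
+//   out[indices[i]]  = elementwise_op(in[i]);   // InMemoryDataOperationEnum::Set
+//   out[indices[i]] += elementwise_op(in[i]);   // InMemoryDataOperationEnum::AtomicAdd
+//   out[indices[i]]  = max(out[indices[i]], elementwise_op(in[i]));   // AtomicMax
+//
+// Negative indices are skipped, and the Set path assumes the caller guarantees that all
+// indices are distinct (see the comment inside Run).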
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc, + const InDataType* __restrict__ p_in_global, + const IndexDataType* __restrict__ p_indices_global, + OutDataType* __restrict__ p_out_global, + const ElementwiseOperation elementwise_op) +{ + GridwisePutElementwise1dFunctor::Run( + in_grid_1d_desc, p_in_global, p_indices_global, p_out_global, elementwise_op); +} + +// output[indices] = input +template +struct GridwisePutElement_1D +{ + static constexpr auto I0 = Number<0>{}; + + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + __device__ static void Run(const InGrid1dDesc& in_grid_1d_desc, + const InDataType* __restrict__ p_in_global, + const IndexDataType* __restrict__ p_indices_global, + OutDataType* __restrict__ p_out_global, + const ElementwiseOperation& elementwise_op) + { + // Global Memory + const auto in_global_buf = make_dynamic_buffer( + p_in_global, in_grid_1d_desc.GetElementSpaceSize()); + + const auto indices_global_buf = + make_dynamic_buffer(p_indices_global, + in_grid_1d_desc.GetElementSpaceSize(), + NumericLimits::Lowest()); + + // VGPR + StaticBuffer in_thread_buf; + StaticBuffer indices_thread_buf; + + // Thread id, Block id and index + const index_t thread_global_id = get_thread_global_1d_id(); + const auto thread_global_offset = make_multi_index(thread_global_id * InVectorSize); + const index_t blockSize = get_block_size(); + const index_t blockPerGrid = get_grid_size(); + const auto M = in_grid_1d_desc.GetLength(I0); + const index_t loop_step = blockPerGrid * blockSize * InVectorSize; + const auto loop_step_index = make_multi_index(loop_step); + + auto in_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + InVectorSize, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{in_grid_1d_desc, thread_global_offset}; + + auto indices_global_load = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0>, // DimAccessOrder + 0, // SrcVectorDim + InVectorSize, // ScalarPerVector + 1, // SrcScalarStrideInVector + false>{in_grid_1d_desc, thread_global_offset}; + + index_t num_iter = M / loop_step; + do + { + in_global_load.Run(in_grid_1d_desc, + in_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + in_thread_buf); + + in_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index); + + static_for<0, InVectorSize, 1>{}( + [&](auto iM) { elementwise_op(in_thread_buf(iM), in_thread_buf[iM]); }); + + indices_global_load.Run(in_grid_1d_desc, + indices_global_buf, + thread_buffer_desc_m, + make_tuple(I0), + indices_thread_buf); + + indices_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index); + + static_for<0, InVectorSize, 1>{}([&](auto iM) { + if(indices_thread_buf[iM] >= 0) + { + if constexpr(MemOp == InMemoryDataOperationEnum::Set) + { + // User should guarantee each index in p_indices_global is different + *(p_out_global + indices_thread_buf[iM]) = + ck::type_convert(in_thread_buf[iM]); + } + else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicAdd) + { + atomic_add(p_out_global + indices_thread_buf[iM], + ck::type_convert(in_thread_buf[iM])); + } + else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicMax) + { + atomic_max(p_out_global + indices_thread_buf[iM], + ck::type_convert(in_thread_buf[iM])); + } + else if 
constexpr(MemOp == InMemoryDataOperationEnum::Add) + { + // User should guarantee each index in p_indices_global is different + *(p_out_global + indices_thread_buf[iM]) += + ck::type_convert(in_thread_buf[iM]); + } + else + { + static_assert(MemOp == InMemoryDataOperationEnum::Set || + MemOp == InMemoryDataOperationEnum::AtomicAdd || + MemOp == InMemoryDataOperationEnum::AtomicMax || + MemOp == InMemoryDataOperationEnum::Add); + } + } + }); + + } while(--num_iter); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp new file mode 100755 index 000000000000..41352fabeb36 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffer_desc, + DataType* const __restrict__ p_global, + DataType value) + +{ + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + constexpr auto I0 = Number<0>{}; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const index_t thread_global_id = block_global_id * BlockSize + thread_local_id; + + StaticBuffer value_buf; + + value_buf(I0) = value; + + constexpr auto val_buff_desc = make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); + + auto global_buf = make_dynamic_buffer( + p_global, grid_1d_buffer_desc.GetElementSpaceSize()); + + if(thread_global_id < grid_1d_buffer_desc.GetElementSize()) + { + auto threadwise_store = ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + grid_1d_buffer_desc, make_multi_index(thread_global_id), PassThroughOp{}); + + threadwise_store.Run( + val_buff_desc, make_tuple(I0), value_buf, grid_1d_buffer_desc, global_buf); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp new file mode 100755 index 000000000000..0ad36b418a42 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_set_multiple_buffer_value.hpp @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
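// kernel_multiple_buffer_set_value fills NumBuffer independent buffers, each
// with its own scalar from value_tuple; every thread stores the element at its
// global thread id in each buffer. A host-side sketch of the intended result
// (illustrative helper only, assuming the launched grid covers every element
// of every buffer):
//
//   #include <algorithm>
//   #include <array>
//   #include <cstddef>
//
//   template <typename T, std::size_t N>
//   void set_buffers_reference(std::array<T*, N> buffers,
//                              std::array<std::size_t, N> sizes,
//                              std::array<T, N> values)
//   {
//       for(std::size_t b = 0; b < N; ++b)
//           std::fill(buffers[b], buffers[b] + sizes[b], values[b]);
//   }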
+ +#pragma once + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" + +namespace ck { + +template +__global__ void +kernel_multiple_buffer_set_value(const Grid1dBufferDescTuple grid_1d_buffer_desc_tuple, + DataTypePointerTuple p_global_tuple, + DataTypeTuple value_tuple) + +{ + static_assert(NumBuffer == DataTypePointerTuple::Size() && NumBuffer == DataTypeTuple::Size(), + "The tuple size should be same as NumBuffer!"); + + static_for<0, NumBuffer, 1>{}([&](auto iB) { + using DataTypePointer = remove_cvref_t; + using DataTypeFromPointer = remove_pointer_t; + using DataType = remove_cvref_t; + + static_assert(is_same::value, + "Types in tuples does not match!"); + }); + + constexpr auto I0 = Number<0>{}; + + const index_t thread_global_id = get_thread_global_1d_id(); + + auto value_buf_tuple = generate_tuple( + [&](auto iB) { + using DataType = remove_cvref_t; + + return StaticBuffer{}; + }, + Number{}); + + static_for<0, NumBuffer, 1>{}([&](auto iB) { + static_for<0, 1, 1>{}([&](auto J) { value_buf_tuple(iB)(J) = value_tuple[iB]; }); + }); + + auto global_buf_tuple = generate_tuple( + [&](auto iB) { + return make_dynamic_buffer( + p_global_tuple(iB), grid_1d_buffer_desc_tuple[iB].GetElementSpaceSize()); + }, + Number{}); + + constexpr auto val_buff_desc = make_naive_tensor_descriptor_packed(make_tuple(Number<1>{})); + + static_for<0, NumBuffer, 1>{}([&](auto iB) { + using DataType = remove_cvref_t; + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + auto threadwise_store = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0>, + 0, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + grid_1d_buffer_desc_tuple[iB], make_multi_index(thread_global_id), PassThroughOp{}); + + threadwise_store.Run(val_buff_desc, + make_tuple(I0), + value_buf_tuple(iB), + grid_1d_buffer_desc_tuple[iB], + global_buf_tuple(iB)); + }); +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp new file mode 100755 index 000000000000..5f56ac6fc466 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp @@ -0,0 +1,407 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
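// GridwiseSoftmax_mk_to_mk computes a numerically stable softmax along the K
// axis of an [M, K] tensor and blends it with the prior output:
//
//     out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
//
// The kernel below does this with up to three sweeps over K (row max, sum of
// exponentials, normalized write-out). A host-side sketch of the semantics
// (illustrative helper only, not part of this header):
//
//   #include <algorithm>
//   #include <cmath>
//   #include <vector>
//
//   void softmax_rows_reference(const std::vector<float>& in, // M x K, row-major
//                               std::vector<float>& out,      // holds prior output on entry
//                               int M, int K, float alpha, float beta)
//   {
//       for(int m = 0; m < M; ++m)
//       {
//           const float* x = in.data() + m * K;
//           float* y       = out.data() + m * K;
//           const float x_max = *std::max_element(x, x + K);
//           float sum = 0.f;
//           for(int k = 0; k < K; ++k)
//               sum += std::exp(x[k] - x_max);
//           for(int k = 0; k < K; ++k)
//               y[k] = alpha * std::exp(x[k] - x_max) / sum + beta * y[k];
//       }
//   }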
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/reduction_common.hpp" +#include "ck/utility/reduction_operator.hpp" +#include "ck/utility/reduction_functions_accumulate.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k, + const GridDesc_M_K out_grid_desc_m_k, + index_t block_group_size, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global) +{ + GridwiseReduction::Run(in_grid_desc_m_k, + out_grid_desc_m_k, + block_group_size, + num_k_block_tile_iteration, + alpha, + p_in_value_global, + beta, + p_out_value_global); +}; + +template +struct GridwiseSoftmax_mk_to_mk +{ + static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) || + (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) && + (KThreadSliceSize % OutDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const GridDesc_M_K& in_grid_desc_m_k, + const GridDesc_M_K& out_grid_desc_m_k, + index_t block_group_size, + index_t num_k_block_tile_iteration, + AccDataType alpha, + const InDataType* const __restrict__ p_in_value_global, + AccDataType beta, + OutDataType* const __restrict__ p_out_value_global) + { + if constexpr(SweepOnce) + { + num_k_block_tile_iteration = 1; + } + + // LDS + __shared__ AccDataType p_reduce_work_buffer[BlockSize]; + + auto out_global_val_buf = make_dynamic_buffer( + p_out_value_global, out_grid_desc_m_k.GetElementSpaceSize()); + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + StaticBuffer + in_thread_buf; + + StaticBuffer + out_thread_buf; + + StaticBuffer max_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + max_value_buf(I) = reduce::Max::template GetIdentityValue(); + }); + + StaticBuffer accu_value_buf; + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + accu_value_buf(I) = reduce::Add::template GetIdentityValue(); + }); + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t blkgroup_id = 
block_global_id / block_group_size; + const index_t block_local_id = block_global_id % block_group_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + using ThreadBufferLengths = Sequence; + constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + // Normally, 0 as invalid element value is adequate since 0 makes no contribution to + // accumulated result. However, in stable softmax, all values 0s or not are subtracted by + // another value_max. As numbers become non-zero, effectively it allows invalid values to + // slip through and contribute to the accumulated result. + // + // The trick here is leveraging the fact that many math functions (add, sub, exp, ...) + // propagate NaNs when operands have NaNs involved. By initialiing invalid element value + // with NaN, an invalid value doing math manipulations is still NaN, which in turn can still + // be identified as an invalid value. We can then discard the invalid values which + // originally failed the bound check during accumulation. This allows to ignore values that + // failed bound check even after multiple math manipulations. + // + // NOTE: reset coordinate after every step because the same threadwise copy will sweep + // through global memory 3 times back and forth + auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2( + in_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2( + out_grid_desc_m_k, + make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dst_store = + ThreadwiseTensorSliceTransfer_v1r3( + out_grid_desc_m_k, + make_multi_index( + blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize), + PassThroughOp{}); + + constexpr auto in_thread_copy_fwd_step = + make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize); + constexpr auto in_thread_copy_bwd_step = + make_multi_index(0, SweepOnce ? 
0 : -K_BlockTileSize); + + /// + /// max(x) + /// + using BlockwiseMaxReduce = PartitionedBlockwiseReduction< + AccDataType, + BlockSize, + ThreadClusterLengths_M_K, + ThreadClusterArrangeOrder, + reduce::Max, + false, // param ignored + detail::AccumulateWithNanIgnore>; + + using ThreadwiseMaxReduce = + ThreadwiseReduction>; + + const auto in_global_val_buf = make_dynamic_buffer( + p_in_value_global, in_grid_desc_m_k.GetElementSpaceSize()); + + index_t reducedTiles = 0; + do + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + + ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I)); + block_sync_lds(); + }); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step); + + /// + /// sum(exp(x - max(x))) + /// + using BlockwiseSumReduce = PartitionedBlockwiseReduction< + AccDataType, + BlockSize, + ThreadClusterLengths_M_K, + ThreadClusterArrangeOrder, + reduce::Add, + false, // ignored + detail::AccumulateWithNanIgnore>; + + using ThreadwiseSumReduce = + ThreadwiseReduction>; + + reducedTiles = 0; + do + { + if constexpr(!SweepOnce) + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + } + + // do element-wise pre-reduction operation + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + out_thread_buf(Number{}) = + math::exp(in_thread_buf(Number{}) - max_value_buf(iM)); + }); + }); + + ThreadwiseSumReduce::Reduce(out_thread_buf, accu_value_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + + block_sync_lds(); // wait for reading being complete before writing to LDS + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I)); + block_sync_lds(); + }); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step); + + /// + /// softmax + /// + reducedTiles = 0; + if(float_equal_zero{}(beta)) + { + do + { + if constexpr(!SweepOnce) + { + threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + out_thread_buf(Number{}) = + alpha * math::exp(in_thread_buf(Number{}) - max_value_buf(iM)) / + accu_value_buf(iM); + }); + }); + + threadwise_dst_store.Run(thread_buffer_desc, + make_tuple(I0, I0), + out_thread_buf, + out_grid_desc_m_k, + out_global_val_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step); + threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + } + else + { + StaticBuffer + in_prior_dst_buf; + do + { + if constexpr(!SweepOnce) + { 
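                    // multi-tile rows (SweepOnce == false) do not keep the input
                    // tile in registers across sweeps, so re-read it for this K
                    // step before blending in the prior output with beta below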
+ threadwise_src_load.Run(in_grid_desc_m_k, + in_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_thread_buf); + } + threadwise_dst_load.Run(out_grid_desc_m_k, + out_global_val_buf, + thread_buffer_desc, + make_tuple(I0, I0), + in_prior_dst_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset = + thread_buffer_desc.CalculateOffset(make_tuple(iM, iK)); + out_thread_buf(Number{}) = + alpha * math::exp(in_thread_buf(Number{}) - max_value_buf(iM)) / + accu_value_buf(iM) + + beta * in_prior_dst_buf(Number{}); + }); + }); + + threadwise_dst_store.Run(thread_buffer_desc, + make_tuple(I0, I0), + out_thread_buf, + out_grid_desc_m_k, + out_global_val_buf); + + threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step); + threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step); + threadwise_dst_load.MoveSrcSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step); + + reducedTiles++; + } while(reducedTiles < num_k_block_tile_iteration); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp new file mode 100755 index 000000000000..287b4e542168 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp @@ -0,0 +1,319 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" + +namespace ck { + +template +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + __global__ void kernel_sparse_embeddings_forward_layernorm( + OutType* p_out, + const ck::Array p_embs, + const ck::Array p_indexes, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + const OutGridDesc out_grid_desc, + const AccDataType epsilon, + const EmbElementwiseOperation emb_elementwise_op) +{ + GridwiseSparseEmbedding::Run( + p_out, p_embs, p_indexes, p_gamma, p_beta, out_grid_desc, epsilon, emb_elementwise_op); +} + +template +struct GridwiseSparseEmbeddingsForwardLayernorm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr index_t WaveSize = 64; + + static_assert(BlockSize == RowClusterSize * DimClusterSize, + "Invalid cluster distribution within block"); + static_assert(RowClusterSize % WaveSize == 0, "need to be wavewise"); + + static_assert(DimPerBlock % (DimClusterSize * DimThreadSize) == 0, ""); + static_assert(RowPerBlock % (RowClusterSize * RowVectorSize) == 0, ""); + + static constexpr auto DimSubBlocks = DimPerBlock / (DimClusterSize * DimThreadSize); + static constexpr auto RowSubBlocks = RowPerBlock / (RowClusterSize * RowVectorSize); + + static_assert((DimPerBlock % DimSubBlocks == 0) && (RowPerBlock % RowSubBlocks == 0), ""); + static constexpr auto DimPerSubBlock = DimPerBlock / DimSubBlocks; + static 
constexpr auto RowPerSubBlock = RowPerBlock / RowSubBlocks; + + using ThreadwiseWolfordDesc2D = decltype(make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}))); + + using ThreadwiseWolfordDescReduce = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using ThreadClusterLength = Sequence; + + using BlockwiseWelford = + BlockwiseWelford>; + + __device__ static void Run(OutType* p_out, + const ck::Array p_embs, + const ck::Array p_indexes, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + const OutGridDesc, + const AccDataType epsilon, + const EmbElementwiseOperation emb_elementwise_op) + { + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + constexpr auto thread_cluster_desc = + make_cluster_descriptor(Sequence{}, Sequence<0, 1>{}); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_dim_cluster_id = thread_cluster_idx[I0]; + const auto thread_row_cluster_id = thread_cluster_idx[I1]; + + const auto wave_dim_id = __builtin_amdgcn_readfirstlane(thread_dim_cluster_id / WaveSize); + + const auto index_start = block_global_id * DimPerBlock + wave_dim_id * DimThreadSize; + + auto threadwise_welford = ThreadwiseWelford(); + threadwise_welford.max_count_ = RowSubBlocks * RowVectorSize; + + constexpr auto thread_buf_size = + DimSubBlocks * DimThreadSize * RowSubBlocks * RowVectorSize; + constexpr auto thread_buf_desc = make_naive_tensor_descriptor_packed( + make_tuple(DimSubBlocks, DimThreadSize, RowSubBlocks, RowVectorSize)); + constexpr auto mean_var_buf_size = DimSubBlocks * DimThreadSize; + constexpr auto mean_var_buf_desc = + make_naive_tensor_descriptor_packed(make_tuple(DimSubBlocks, DimThreadSize)); + constexpr auto gamma_beta_buf_size = RowSubBlocks * RowVectorSize; + constexpr auto gamma_beta_buf_desc = + make_naive_tensor_descriptor_packed(make_tuple(RowSubBlocks, RowVectorSize)); + + ck::Array, + NumEmbeddings> + in_thread_bufs; + ck::Array, NumEmbeddings> + index_bufs; + + StaticBuffer acc_thread_buf; + + StaticBuffer + gamma_thread_buf; + StaticBuffer + beta_thread_buf; + + StaticBuffer mean_thread_buf; + StaticBuffer var_thread_buf; + + auto load_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) { + ck::Array, NumEmbeddings> emb_vectors; + auto emb_a = emb_vectors[0]; + using src_vector_t = typename decltype(emb_a)::type; + static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { + constexpr auto current_dim = i_dim_sub_ * DimPerSubBlock + i_dim_vec_; + + auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) * + sizeof(EmbType) * RowVectorSize; + static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { + IndexType index = index_bufs[i_embedding_][Number{}]; + + int32x4_t emb_res = make_wave_buffer_resource_with_default_range( + p_embs[i_embedding_] + index * RowPerBlock); + emb_vectors(i_embedding_).template AsType()(I0) = + amd_buffer_load_impl(emb_res, thread_offset, 0); + }); + + static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { + in_thread_bufs(i_embedding_)(Number{}) = + ck::type_convert( + emb_vectors[i_embedding_].template AsType()[i_row_vec_]); + }); + }); + }); + }; + + auto 
accumulate_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) { + static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { + static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + auto in_data_refs = generate_tie( + [&](auto i_embedding_) -> const auto& { + return in_thread_bufs(i_embedding_)(Number{}); + }, + Number{}); + auto out_data_refs = generate_tie( + [&](auto) -> auto& { return acc_thread_buf(Number{}); }, + Number<1>{}); + unpack2(emb_elementwise_op, out_data_refs, in_data_refs); + }); + }); + }; + + auto threadwise_welford_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) { + static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { + static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + constexpr auto mean_var_offset = + mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_)); + + threadwise_welford.cur_count_++; + threadwise_welford.Update(mean_thread_buf(Number{}), + var_thread_buf(Number{}), + acc_thread_buf(Number{})); + }); + }); + }; + + auto threadwise_normalize_store_out = [&](auto i_dim_sub_, auto i_row_sub_) { + int32x4_t out_res = + make_wave_buffer_resource_with_default_range(p_out + index_start * RowPerBlock); + static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { + vector_type_maker_t out_vector; + using dst_vector_t = typename decltype(out_vector)::type; + + constexpr auto mean_var_offset = + mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_)); + auto divisor = + 1 / __builtin_amdgcn_sqrtf(var_thread_buf(Number{}) + epsilon); + static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + constexpr auto gamma_beta_offset = + gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_)); + + auto acc_val = acc_thread_buf[Number{}]; + acc_val = (acc_val - mean_thread_buf(Number{})) * divisor; + acc_val = acc_val * gamma_thread_buf[Number{}] + + beta_thread_buf[Number{}]; + + out_vector.template AsType()(Number{}) = + type_convert(acc_val); + }); + + index_t thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) * + sizeof(OutType) * RowVectorSize; + + amd_buffer_store_impl( + out_vector.template AsType()[Number<0>{}], + out_res, + thread_offset, + 0); + }); + }; + + // first load index + ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) { + // prefer use s_load + ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { + index_bufs(i_embedding_)(i_idx_) = + p_indexes[i_embedding_][index_start + i_idx_.value]; + }); + }); + + // load gamma/beta + static_for<0, RowSubBlocks, 1>{}([&](auto i_row_sub_) { + vector_type_maker_t gamma_vector; + vector_type_maker_t beta_vector; + + index_t thread_offset_gamma = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) * + sizeof(GammaDataType) * RowVectorSize; + index_t thread_offset_beta = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) * + sizeof(BetaDataType) * RowVectorSize; + + int32x4_t gamma_res = make_wave_buffer_resource_with_default_range(p_gamma); + int32x4_t beta_res = make_wave_buffer_resource_with_default_range(p_beta); + + gamma_vector.template AsType()(I0) = + amd_buffer_load_impl( + gamma_res, thread_offset_gamma, 0); + beta_vector.template 
AsType()(I0) = + amd_buffer_load_impl(beta_res, thread_offset_beta, 0); + + static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { + constexpr auto offset = + gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_)); + gamma_thread_buf(Number{}) = type_convert( + gamma_vector.template AsType()[Number{}]); + beta_thread_buf(Number{}) = type_convert( + beta_vector.template AsType()[Number{}]); + }); + }); + + static_for<0, thread_buf_size, 1>{}( + [&](auto I) { acc_thread_buf(I) = type_convert(0.0f); }); + + static_for<0, mean_var_buf_size, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + static_for<0, DimSubBlocks, 1>{}([&](auto i_dim_sub) { + load_current_sub_row(i_dim_sub, Number<0>{}); + static_for<0, RowSubBlocks - 1, 1>{}([&](auto i_row) { + load_current_sub_row(i_dim_sub, Number<1>{} + i_row); + accumulate_current_sub_row(i_dim_sub, i_row); + threadwise_welford_sub_row(i_dim_sub, i_row); + }); + accumulate_current_sub_row(i_dim_sub, Number{}); + threadwise_welford_sub_row(i_dim_sub, Number{}); + + // blockwise welford + static_for<0, mean_var_buf_size, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + BlockwiseWelford::Run( + mean_thread_buf(I), var_thread_buf(I), threadwise_welford.cur_count_); + }); + + // store + static_for<0, RowSubBlocks, 1>{}( + [&](auto i_row) { threadwise_normalize_store_out(i_dim_sub, i_row); }); + }); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm_builtins.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm_builtins.hpp new file mode 100755 index 000000000000..7c3e37276599 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm_builtins.hpp @@ -0,0 +1,322 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
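// This header mirrors gridwise_sparse_embeddings_forward_layernorm.hpp; the
// main difference is that buffer loads and stores go through the
// __amdgpu_buffer_rsrc_t builtin resource type rather than raw int32x4_t
// descriptors. The fused operation is the same: gather one row from each
// embedding table, combine them elementwise, then layer-normalize the combined
// row with gamma/beta. Host-side sketch of the per-row semantics (illustrative
// helper only, assuming the embedding elementwise op is a plain add):
//
//   #include <cmath>
//   #include <cstdint>
//   #include <vector>
//
//   // One output row d; E embedding tables of width `row`, combined by addition.
//   inline void sparse_emb_layernorm_row_reference(
//       const std::vector<const float*>& embs,
//       const std::vector<const std::int64_t*>& idx,
//       const float* gamma, const float* beta, float* out,
//       std::int64_t d, int row, float eps)
//   {
//       std::vector<float> acc(row, 0.f);
//       for(std::size_t e = 0; e < embs.size(); ++e)
//           for(int r = 0; r < row; ++r)
//               acc[r] += embs[e][idx[e][d] * row + r];
//       float mean = 0.f;
//       float var  = 0.f;
//       for(int r = 0; r < row; ++r)
//           mean += acc[r];
//       mean /= row;
//       for(int r = 0; r < row; ++r)
//           var += (acc[r] - mean) * (acc[r] - mean);
//       var /= row;
//       const float rstd = 1.f / std::sqrt(var + eps);
//       for(int r = 0; r < row; ++r)
//           out[d * row + r] = (acc[r] - mean) * rstd * gamma[r] + beta[r];
//   }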
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" + +namespace ck { + +template +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + __global__ void kernel_sparse_embeddings_forward_layernorm( + OutType* p_out, + const ck::Array p_embs, + const ck::Array p_indexes, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + const OutGridDesc out_grid_desc, + const AccDataType epsilon, + const EmbElementwiseOperation emb_elementwise_op) +{ + GridwiseSparseEmbedding::Run( + p_out, p_embs, p_indexes, p_gamma, p_beta, out_grid_desc, epsilon, emb_elementwise_op); +} + +template +struct GridwiseSparseEmbeddingsForwardLayernorm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr index_t WaveSize = 64; + + static_assert(BlockSize == RowClusterSize * DimClusterSize, + "Invalid cluster distribution within block"); + static_assert(RowClusterSize % WaveSize == 0, "need to be wavewise"); + + static_assert(DimPerBlock % (DimClusterSize * DimThreadSize) == 0, ""); + static_assert(RowPerBlock % (RowClusterSize * RowVectorSize) == 0, ""); + + static constexpr auto DimSubBlocks = DimPerBlock / (DimClusterSize * DimThreadSize); + static constexpr auto RowSubBlocks = RowPerBlock / (RowClusterSize * RowVectorSize); + + static_assert((DimPerBlock % DimSubBlocks == 0) && (RowPerBlock % RowSubBlocks == 0), ""); + static constexpr auto DimPerSubBlock = DimPerBlock / DimSubBlocks; + static constexpr auto RowPerSubBlock = RowPerBlock / RowSubBlocks; + + using ThreadwiseWolfordDesc2D = decltype(make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}))); + + using ThreadwiseWolfordDescReduce = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using ThreadClusterLength = Sequence; + + using BlockwiseWelford = + BlockwiseWelford>; + + __device__ static void Run(OutType* p_out, + const ck::Array p_embs, + const ck::Array p_indexes, + const GammaDataType* p_gamma, + const BetaDataType* p_beta, + const OutGridDesc, + const AccDataType epsilon, + const EmbElementwiseOperation emb_elementwise_op) + { + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + constexpr auto thread_cluster_desc = + make_cluster_descriptor(Sequence{}, Sequence<0, 1>{}); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_dim_cluster_id = thread_cluster_idx[I0]; + const auto thread_row_cluster_id = thread_cluster_idx[I1]; + + const auto wave_dim_id = __builtin_amdgcn_readfirstlane(thread_dim_cluster_id / WaveSize); + + const auto index_start = block_global_id * DimPerBlock + wave_dim_id * DimThreadSize; + + auto threadwise_welford = ThreadwiseWelford(); + threadwise_welford.max_count_ = RowSubBlocks * RowVectorSize; + + constexpr auto thread_buf_size = + DimSubBlocks * DimThreadSize * RowSubBlocks * RowVectorSize; + constexpr auto thread_buf_desc = make_naive_tensor_descriptor_packed( + make_tuple(DimSubBlocks, DimThreadSize, RowSubBlocks, RowVectorSize)); + constexpr auto mean_var_buf_size = DimSubBlocks * 
DimThreadSize; + constexpr auto mean_var_buf_desc = + make_naive_tensor_descriptor_packed(make_tuple(DimSubBlocks, DimThreadSize)); + constexpr auto gamma_beta_buf_size = RowSubBlocks * RowVectorSize; + constexpr auto gamma_beta_buf_desc = + make_naive_tensor_descriptor_packed(make_tuple(RowSubBlocks, RowVectorSize)); + + ck::Array, + NumEmbeddings> + in_thread_bufs; + ck::Array, NumEmbeddings> + index_bufs; + + StaticBuffer acc_thread_buf; + + StaticBuffer + gamma_thread_buf; + StaticBuffer + beta_thread_buf; + + StaticBuffer mean_thread_buf; + StaticBuffer var_thread_buf; + + auto load_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) { + ck::Array, NumEmbeddings> emb_vectors; + auto emb_a = emb_vectors[0]; + using src_vector_t = typename decltype(emb_a)::type; + static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { + constexpr auto current_dim = i_dim_sub_ * DimPerSubBlock + i_dim_vec_; + + auto thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) * + sizeof(EmbType) * RowVectorSize; + static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { + IndexType index = index_bufs[i_embedding_][Number{}]; + + __amdgpu_buffer_rsrc_t emb_res = + make_wave_buffer_resource_with_default_range_new(p_embs[i_embedding_] + + index * RowPerBlock); + emb_vectors(i_embedding_).template AsType()(I0) = + amd_buffer_load_impl(emb_res, thread_offset, 0); + }); + + static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { + in_thread_bufs(i_embedding_)(Number{}) = + ck::type_convert( + emb_vectors[i_embedding_].template AsType()[i_row_vec_]); + }); + }); + }); + }; + + auto accumulate_current_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) { + static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { + static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + auto in_data_refs = generate_tie( + [&](auto i_embedding_) -> const auto& { + return in_thread_bufs(i_embedding_)(Number{}); + }, + Number{}); + auto out_data_refs = generate_tie( + [&](auto) -> auto& { return acc_thread_buf(Number{}); }, + Number<1>{}); + unpack2(emb_elementwise_op, out_data_refs, in_data_refs); + }); + }); + }; + + auto threadwise_welford_sub_row = [&](auto i_dim_sub_, auto i_row_sub_) { + static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { + static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + constexpr auto mean_var_offset = + mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_)); + + threadwise_welford.cur_count_++; + threadwise_welford.Update(mean_thread_buf(Number{}), + var_thread_buf(Number{}), + acc_thread_buf(Number{})); + }); + }); + }; + + auto threadwise_normalize_store_out = [&](auto i_dim_sub_, auto i_row_sub_) { + __amdgpu_buffer_rsrc_t out_res = + make_wave_buffer_resource_with_default_range_new(p_out + index_start * RowPerBlock); + static_for<0, DimThreadSize, 1>{}([&](auto i_dim_vec_) { + vector_type_maker_t out_vector; + using dst_vector_t = typename decltype(out_vector)::type; + + constexpr auto mean_var_offset = + mean_var_buf_desc.CalculateOffset(make_tuple(i_dim_sub_, i_dim_vec_)); + auto divisor = + 1 / 
__builtin_amdgcn_sqrtf(var_thread_buf(Number{}) + epsilon); + static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { + constexpr auto register_offset = thread_buf_desc.CalculateOffset( + make_tuple(i_dim_sub_, i_dim_vec_, i_row_sub_, i_row_vec_)); + constexpr auto gamma_beta_offset = + gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_)); + + auto acc_val = acc_thread_buf[Number{}]; + acc_val = (acc_val - mean_thread_buf(Number{})) * divisor; + acc_val = acc_val * gamma_thread_buf[Number{}] + + beta_thread_buf[Number{}]; + + out_vector.template AsType()(Number{}) = + type_convert(acc_val); + }); + + index_t thread_offset = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) * + sizeof(OutType) * RowVectorSize; + + amd_buffer_store_impl( + out_vector.template AsType()[Number<0>{}], + out_res, + thread_offset, + 0); + }); + }; + + // first load index + ck::static_for<0, DimPerBlock, 1>{}([&](auto i_idx_) { + // prefer use s_load + ck::static_for<0, NumEmbeddings, 1>{}([&](auto i_embedding_) { + index_bufs(i_embedding_)(i_idx_) = + p_indexes[i_embedding_][index_start + i_idx_.value]; + }); + }); + + // load gamma/beta + static_for<0, RowSubBlocks, 1>{}([&](auto i_row_sub_) { + vector_type_maker_t gamma_vector; + vector_type_maker_t beta_vector; + + index_t thread_offset_gamma = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) * + sizeof(GammaDataType) * RowVectorSize; + index_t thread_offset_beta = (thread_row_cluster_id + i_row_sub_ * RowClusterSize) * + sizeof(BetaDataType) * RowVectorSize; + + __amdgpu_buffer_rsrc_t gamma_res = + make_wave_buffer_resource_with_default_range_new(p_gamma); + __amdgpu_buffer_rsrc_t beta_res = + make_wave_buffer_resource_with_default_range_new(p_beta); + + gamma_vector.template AsType()(I0) = + amd_buffer_load_impl( + gamma_res, thread_offset_gamma, 0); + beta_vector.template AsType()(I0) = + amd_buffer_load_impl(beta_res, thread_offset_beta, 0); + + static_for<0, RowVectorSize, 1>{}([&](auto i_row_vec_) { + constexpr auto offset = + gamma_beta_buf_desc.CalculateOffset(make_tuple(i_row_sub_, i_row_vec_)); + gamma_thread_buf(Number{}) = type_convert( + gamma_vector.template AsType()[Number{}]); + beta_thread_buf(Number{}) = type_convert( + beta_vector.template AsType()[Number{}]); + }); + }); + + static_for<0, thread_buf_size, 1>{}( + [&](auto I) { acc_thread_buf(I) = type_convert(0.0f); }); + + static_for<0, mean_var_buf_size, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + static_for<0, DimSubBlocks, 1>{}([&](auto i_dim_sub) { + load_current_sub_row(i_dim_sub, Number<0>{}); + static_for<0, RowSubBlocks - 1, 1>{}([&](auto i_row) { + load_current_sub_row(i_dim_sub, Number<1>{} + i_row); + accumulate_current_sub_row(i_dim_sub, i_row); + threadwise_welford_sub_row(i_dim_sub, i_row); + }); + accumulate_current_sub_row(i_dim_sub, Number{}); + threadwise_welford_sub_row(i_dim_sub, Number{}); + + // blockwise welford + static_for<0, mean_var_buf_size, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + BlockwiseWelford::Run( + mean_thread_buf(I), var_thread_buf(I), threadwise_welford.cur_count_); + }); + + // store + static_for<0, RowSubBlocks, 1>{}( + [&](auto i_row) { threadwise_normalize_store_out(i_dim_sub, i_row); }); + }); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp new file mode 
100755 index 000000000000..ddf0b4a58d68 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_tensor_rearrange(const InputGridDesc in_grid_desc, + const InputDataType* __restrict__ p_in_global, + const OutputGridDesc out_grid_desc, + OutputDataType* __restrict__ p_out_global, + const index_t batch_count, + const Block2ETileMap block_2_tile_map, + const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ + defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \ + defined(__gfx12__)) + GridwiseTensorRearrangeKernel::Run(in_grid_desc, + p_in_global, + out_grid_desc, + p_out_global, + batch_count, + block_2_tile_map, + compute_ptr_offset_of_batch); +#else + ignore = in_grid_desc; + ignore = p_in_global; + ignore = out_grid_desc; + ignore = p_out_global; + ignore = batch_count; + ignore = block_2_tile_map; + ignore = compute_ptr_offset_of_batch; +#endif +} + +template +struct GridwiseTensorRearrange +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + using ThisThreadBlock = ThisThreadBlock; + + __device__ static void Run(const InputGridDesc& in_grid_desc, + const InputDataType* __restrict__ p_in_global, + const OutputGridDesc& out_grid_desc, + OutputDataType* __restrict__ p_out_global, + const index_t batch_count, + const Block2ETileMap& block_2_tile_map, + const ComputePtrOffsetOfStridedBatch& compute_ptr_offset_of_batch) + { + const auto block_work_idx = + block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); + + const index_t k_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock); + + auto copy_global_to_global = + ThreadGroupTensorSliceTransfer_v7, + Tuple, + decltype(tie(in_grid_desc)), + decltype(tie(out_grid_desc)), + tensor_operation::element_wise::PassThrough, + Sequence(DstInMemOp)>, + Sequence, + ThreadClusterLengths, + Sequence<0, 1>, + Sequence<0, 1>, + I1, + ScalarPerVector, + Sequence, + Sequence>{ + in_grid_desc, + make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)), + out_grid_desc, + make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)), + tensor_operation::element_wise::PassThrough{}}; + + const index_t num_blocks_per_batch = + 
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); + const index_t g_idx = + __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); + + // Global Memory + const index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); + + const auto in_global_buf = make_dynamic_buffer( + p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize()); + auto out_global_buf = make_dynamic_buffer( + p_out_global + c_batch_offset, out_grid_desc.GetElementSpaceSize()); + + copy_global_to_global.Run( + tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf)); + } + + __host__ static constexpr bool CheckValidity(const InputGridDesc& in_grid_desc, + const OutputGridDesc& out_grid_desc) + { + if(in_grid_desc.GetLength(I0) % MPerBlock != 0 || + in_grid_desc.GetLength(I1) % KPerBlock != 0) + return false; + if(out_grid_desc.GetLength(I0) % MPerBlock != 0 || + out_grid_desc.GetLength(I1) % KPerBlock != 0) + return false; + return true; + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp new file mode 100755 index 000000000000..8a0e16d7f6e0 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp @@ -0,0 +1,554 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" + +namespace ck { + +// Tensor Shape +// dy, x = [M, K], gamma = [1, K], x_mean, inv_std = [M, 1] + +// Flow: +// def normalization_backward_x(dy, x, gamma, x_mean, inv_std, reduce_axis, reduce_size): +// ds = np.sum(dy * gamma * x, axis=reduce_axis, keepdims=True) +// db = np.sum(dy * gamma, axis=reduce_axis, keepdims=True) +// b = (db * x_mean - ds) * inv_std ** (3) / reduce_size +// c = -b * x_mean - db * inv_std / reduce_size +// dx = inv_std * dy * gamma + b * x + c +// return dx + +template +struct GridwiseNormalizationBwdData_mk_to_mk +{ + // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce) + static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) || + (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)), + "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); + + static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) || + (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)), + "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + + static_assert( + ((GammaSrcVectorDim == 0 && MThreadSliceSize == GammaSrcVectorSize) || + (GammaSrcVectorDim == 1 && KThreadSliceSize == GammaSrcVectorSize)), + "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!"); + + static_assert( + ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize == MeanInvStdSrcVectorSize) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize == 
MeanInvStdSrcVectorSize)), + "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!"); + + static_assert(((DXDstVectorDim == 0 && MThreadSliceSize == DXDstVectorSize) || + (DXDstVectorDim == 1 && KThreadSliceSize == DXDstVectorSize)), + "Invalid thread slice sizes and/or dx vector sizes configuration, please check!"); + + using ThreadClusterLengths_M_K = Sequence; + + using DYThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using XThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using GammaThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using MeanInvStdThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + using DXThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder; + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_K = Sequence; + + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + using BlockwiseSumReduce = PartitionedBlockwiseReduction; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k, + const GridDesc_M_K& x_grid_desc_m_k, + const GridDesc_M_K& gamma_grid_desc_m_k, + const GridDesc_M_K& mean_grid_desc_m_k, + const GridDesc_M_K& inv_std_grid_desc_m_k, + const GridDesc_M_K& dx_grid_desc_m_k, + index_t num_k_block_tile_iteration, + const DYDataType* const __restrict__ p_dy_global, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const MeanInvStdDataType* const __restrict__ p_mean_global, + const MeanInvStdDataType* const __restrict__ p_inv_std_global, + DXDataType* const __restrict__ p_dx_global) + { + // LDS + __shared__ ComputeDataType p_reduce_work_buffer[BlockSize]; + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + // Global + const auto dy_global_val_buf = make_dynamic_buffer( + p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize()); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + auto gamma_global_val_buf = make_dynamic_buffer( + p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize()); + + const auto mean_global_val_buf = make_dynamic_buffer( + p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize()); + + const auto inv_std_global_val_buf = make_dynamic_buffer( + p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize()); + + auto dx_global_val_buf = make_dynamic_buffer( + p_dx_global, dx_grid_desc_m_k.GetElementSpaceSize()); + + // VGPR + auto dy_thread_buf = StaticBuffer{}; + + auto x_thread_buf = StaticBuffer{}; + + auto gamma_thread_buf = StaticBuffer{}; + + auto mean_thread_buf = StaticBuffer{}; + + auto inv_std_thread_buf = StaticBuffer{}; + + auto dx_thread_buf = StaticBuffer{}; + + auto ds_thread_buf = + 
StaticBuffer{}; + + auto db_thread_buf = + StaticBuffer{}; + + // thread id + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + // IO + auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2( + dy_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_gamma_load = + ThreadwiseTensorSliceTransfer_v2( + gamma_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_mean_load = + ThreadwiseTensorSliceTransfer_v2( + mean_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_inv_std_load = + ThreadwiseTensorSliceTransfer_v2( + inv_std_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dx_store = + ThreadwiseTensorSliceTransfer_v1r3( + dx_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize), + PassThroughOp{}); + + ComputeDataType reduce_size = type_convert( + dy_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + ds_thread_buf(I) = type_convert(0.0f); + db_thread_buf(I) = type_convert(0.0f); + }); + + // Separate sweep once and sweep twice pipeline + // Sweep once: for small k, if KThreadClusterSize * KThreadSliceSize > K + // we don't need to use loop to read x, dy, gamma twice + if constexpr(SweepOnce) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf); + + threadwise_mean_load.Run(mean_grid_desc_m_k, + mean_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + mean_thread_buf); + + threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, + inv_std_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + inv_std_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + x_thread_buf[offset_m_k]; + + db_thread_buf(offset_m) += + dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k]; + }); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, 
ds_thread_buf(I)); + block_sync_lds(); + BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I)); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + // b = (db * x_mean - ds) * rstd ** (3) / reduce_size + // c = -b * x_mean - db * rstd / reduce_size + // dx = rstd * dy * gamma + b * x + c + + ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] - + ds_thread_buf[offset_m]; + + b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] / reduce_size; + + ComputeDataType c = -b * mean_thread_buf(offset_m_k); + + c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size; + + dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] + + b * x_thread_buf[offset_m_k] + c; + }); + }); + + threadwise_dx_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + dx_thread_buf, + dx_grid_desc_m_k, + dx_global_val_buf); + + } // end of sweep once + else // Sweep Twice pipeline + { + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf); + + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + x_thread_buf[offset_m_k]; + + db_thread_buf(offset_m) += + dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k]; + }); + }); + } // end of first sweep + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, ds_thread_buf(I)); + block_sync_lds(); + BlockwiseSumReduce::Reduce(reduce_work_buf, db_thread_buf(I)); + }); + + // reverse read for using dy, gamma and x in the cache + constexpr auto thread_copy_bwd_step_m_k = make_multi_index(0, -K_BlockTileSize); + auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k; + + // move to tail + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k); + + // move from start to tail + threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k); + + for(index_t 
reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf); + + threadwise_mean_load.Run(mean_grid_desc_m_k, + mean_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + mean_thread_buf); + + threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, + inv_std_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + inv_std_thread_buf); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + // b = (db * x_mean - ds) * rstd ** (3) / reduce_size + // c = -b * x_mean - db * rstd / reduce_size + // dx = rstd * dy * gamma + b * x + c + + ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] - + ds_thread_buf[offset_m]; + + b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] / reduce_size; + + ComputeDataType c = -b * mean_thread_buf(offset_m_k); + + c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size; + + dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] * + gamma_thread_buf[offset_m_k] * + inv_std_thread_buf[offset_m_k] + + b * x_thread_buf[offset_m_k] + c; + }); + }); + + threadwise_dx_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + dx_thread_buf, + dx_grid_desc_m_k, + dx_global_val_buf); + + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_bwd_step_m_k); + threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, + thread_copy_bwd_step_m_k); + threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, + thread_copy_bwd_step_m_k); + threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k); + } + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp new file mode 100755 index 000000000000..21248e3a0a80 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_gamma_beta.hpp @@ -0,0 +1,352 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
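+//
+// Illustrative scalar sketch (not used by the kernel): the gridwise struct below tiles the
+// (M, K) grid so each thread accumulates partial dgamma/dbeta sums over its K slice, a
+// blockwise sum reduction combines the partials, and only threads with
+// thread_k_cluster_id == 0 store the result. Assuming plain row-major [M, K] buffers and
+// per-element mean/inv_std broadcast from the forward pass (names here are for
+// illustration only), the same reduction is:
+//
+//   for(index_t m = 0; m < M; ++m)
+//   {
+//       ComputeDataType dgamma_m = 0, dbeta_m = 0;
+//       for(index_t k = 0; k < K; ++k)
+//       {
+//           dgamma_m += dy[m][k] * inv_std[m][k] * (x[m][k] - mean[m][k]);
+//           dbeta_m  += dy[m][k];
+//       }
+//       dgamma[m] = dgamma_m;
+//       dbeta[m]  = dbeta_m;
+//   }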
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" + +namespace ck { + +// dgamma = reduce_sum(dy * (x - mean) * inv_std) +// dbeta = reduce_sum(dy) +template +struct GridwiseNormalizationBwdGammaBeta_mk_to_k +{ + // if we just check ThreadSliceSize % VectorSize == 0, the performance may be poor (coalesce) + static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) || + (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)), + "Invalid thread slice sizes and/or dy vector sizes configuration, please check!"); + + static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) || + (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)), + "Invalid thread slice sizes and/or x vector sizes configuration, please check!"); + + // do not force SliceSize == MeanInvStdSrcVectorSize for groupnorm + static_assert( + ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || + (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0)), + "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!"); + + static_assert(MThreadSliceSize == DGammaDstVectorSize && MThreadSliceSize == DBetaDstVectorSize, + "Invalid thread slice sizes and/or dx vector sizes configuration, please check!"); + + using ThreadClusterLengths_M_K = Sequence; + + using DYThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using XThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using MeanInvStdThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = DYThreadBufferDimAccessOrder; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_K = Sequence; + using ThreadBufferLengths_M = Sequence; + + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + using BlockwiseSumReduce = PartitionedBlockwiseReduction; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + + __device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k, + const GridDesc_M_K& x_grid_desc_m_k, + const GridDesc_M_K& mean_grid_desc_m_k, + const GridDesc_M_K& inv_std_grid_desc_m_k, + const GridDesc_M& dgamma_grid_desc_m, + const GridDesc_M& dbeta_grid_desc_m, + index_t num_k_block_tile_iteration, + const DYDataType* const __restrict__ p_dy_global, + const XDataType* const __restrict__ p_x_global, + const MeanInvStdDataType* const __restrict__ p_mean_global, + const MeanInvStdDataType* const __restrict__ p_inv_std_global, + DGammaDataType* const __restrict__ p_dgamma_global, + DBetaDataType* const __restrict__ p_dbeta_global) + { + // LDS + __shared__ ComputeDataType p_reduce_work_buffer[BlockSize]; + + auto reduce_work_buf = + 
make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + // Global + const auto dy_global_val_buf = make_dynamic_buffer( + p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize()); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto mean_global_val_buf = make_dynamic_buffer( + p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize()); + + const auto inv_std_global_val_buf = make_dynamic_buffer( + p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize()); + + auto dgamma_global_val_buf = make_dynamic_buffer( + p_dgamma_global, dgamma_grid_desc_m.GetElementSpaceSize()); + + auto dbeta_global_val_buf = make_dynamic_buffer( + p_dbeta_global, dbeta_grid_desc_m.GetElementSpaceSize()); + + // VGPR + auto dy_thread_buf = StaticBuffer{}; + + auto x_thread_buf = StaticBuffer{}; + + auto mean_thread_buf = StaticBuffer{}; + + auto inv_std_thread_buf = StaticBuffer{}; + + auto dgamma_thread_buf = + StaticBuffer{}; + + auto dbeta_thread_buf = + StaticBuffer{}; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + // IO + auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2( + dy_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_mean_load = + ThreadwiseTensorSliceTransfer_v2( + mean_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_inv_std_load = + ThreadwiseTensorSliceTransfer_v2( + inv_std_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * KThreadSliceSize)); + + auto threadwise_dgamma_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + DGammaDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + dgamma_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_dbeta_store = + ThreadwiseTensorSliceTransfer_v1r3, + 0, + DBetaDstVectorSize, + InMemoryDataOperationEnum::Set, + 1, + true>( + dbeta_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + dgamma_thread_buf(I) = type_convert(0.0f); + dbeta_thread_buf(I) = type_convert(0.0f); + }); + + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + threadwise_dy_load.Run(dy_grid_desc_m_k, + dy_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + dy_thread_buf); + + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf); + + threadwise_mean_load.Run(mean_grid_desc_m_k, + mean_global_val_buf, + 
thread_buffer_desc_m_k, + make_tuple(I0, I0), + mean_thread_buf); + + threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, + inv_std_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + inv_std_thread_buf); + + threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, + thread_copy_fwd_step_m_k); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + constexpr auto offset_m = + Number{}; + + static_for<0, KThreadSliceSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + Number{}; + + dgamma_thread_buf(offset_m) += + dy_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] * + (x_thread_buf[offset_m_k] - mean_thread_buf[offset_m_k]); + + dbeta_thread_buf(offset_m) += dy_thread_buf[offset_m_k]; + }); + }); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, dbeta_thread_buf(I)); + block_sync_lds(); + BlockwiseSumReduce::Reduce(reduce_work_buf, dgamma_thread_buf(I)); + }); + + if(thread_k_cluster_id == 0) + { + threadwise_dgamma_store.Run(thread_buffer_desc_m, + make_tuple(I0), + dgamma_thread_buf, + dgamma_grid_desc_m, + dgamma_global_val_buf); + + threadwise_dbeta_store.Run(thread_buffer_desc_m, + make_tuple(I0), + dbeta_thread_buf, + dbeta_grid_desc_m, + dbeta_global_val_buf); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp new file mode 100755 index 000000000000..1bee1b93f976 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp @@ -0,0 +1,612 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
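+//
+// Illustrative scalar sketch (not used by the kernel): "naive variance" here is the
+// two-moment formulation — one reduction pass produces E[x] and E[x^2], then
+// var(x) = E[x^2] - E[x]^2 and inv_std = 1 / sqrt(var + epsilon). Assuming plain
+// row-major [M, K] buffers, the per-row math implemented below is:
+//
+//   for(index_t m = 0; m < M; ++m)
+//   {
+//       ComputeDataType sum = 0, sum_sq = 0;
+//       for(index_t k = 0; k < K; ++k)
+//       {
+//           sum    += x[m][k];
+//           sum_sq += x[m][k] * x[m][k];
+//       }
+//       const ComputeDataType mean    = sum / K;
+//       const ComputeDataType inv_std = 1 / sqrt(sum_sq / K - mean * mean + epsilon);
+//       for(index_t k = 0; k < K; ++k)
+//           y[m][k] = (x[m][k] - mean) * inv_std * gamma[m][k] + beta[m][k];
+//   }
+//
+// The SweepOnce specialization keeps the x/gamma tiles in registers and writes y without
+// re-reading global memory; the generic path walks the K tiles a second time to normalize.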
+ +#pragma once + +#include "ck/utility/data_type.hpp" + +#include "ck/utility/reduction_operator.hpp" +#include "ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp" +#include "ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// Y = Normalization(X, Beta, Gamma) +template +struct GridwiseNormalizationNaiveVariance_mk_to_mk +{ + static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0, + "Invalid thread slice sizes and/or save mean and inverse std vector sizes " + "configuration, please check!"); + + static_assert(XSrcVectorSize == YDstVectorSize); + static_assert(XSrcVectorSize == GammaSrcVectorSize); + static_assert(XSrcVectorSize == BetaSrcVectorSize); + + static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_K = Sequence; + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + using ThreadBufferLengths_M = Sequence; + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using BlockwiseSumReduce = PartitionedBlockwiseReduction; + + using ThreadwiseSumReduce = ThreadwiseReduction; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize; + + static constexpr auto ThreadBufferNumber = Number{}; + + __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k, + const GridDesc_M_K& gamma_grid_desc_m_k, + const GridDesc_M_K& beta_grid_desc_m_k, + const GridDesc_M_K& y_grid_desc_m_k, + const GridDesc_M& save_mean_grid_desc_m, + const GridDesc_M& save_inv_std_grid_desc_m, + index_t num_k_block_tile_iteration, + ComputeDataType epsilon, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + 
SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, + const YElementwiseOperation y_elementwise_op) + { + // LDS + __shared__ ComputeDataType p_reduce_work_buffer[BlockSize]; + + auto reduce_work_buf = + make_dynamic_buffer(p_reduce_work_buffer, BlockSize); + + auto y_global_val_buf = make_dynamic_buffer( + p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + + auto save_mean_global_val_buf = make_dynamic_buffer( + p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize()); + + auto save_inv_std_global_val_buf = make_dynamic_buffer( + p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize()); + + auto x_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto gamma_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto& beta_thread_buf = gamma_thread_buf; + + auto y_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto& x_square_thread_buf = y_thread_buf; + + StaticBuffer + mean_thread_buf; + StaticBuffer + mean_square_thread_buf; + StaticBuffer& + var_thread_buf = mean_square_thread_buf; + StaticBuffer& + inv_std_thread_buf = mean_square_thread_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * XSrcVectorSize)); + + auto threadwise_gamma_load = + ThreadwiseTensorSliceTransfer_v2( + gamma_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * GammaSrcVectorSize)); + + auto threadwise_beta_load = + ThreadwiseTensorSliceTransfer_v2( + beta_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * BetaSrcVectorSize)); + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * YDstVectorSize), + y_elementwise_op); + + auto threadwise_mean_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_mean_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_inv_std_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_inv_std_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); + constexpr auto thread_copy_bwd_step_m_k = + make_multi_index(0, SweepOnce ? 
0 : -K_BlockTileSize); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto gamma_global_val_buf = make_dynamic_buffer( + p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize()); + + const auto beta_global_val_buf = make_dynamic_buffer( + p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); + + // E(x), E[x^2], var(x) + // FIXME: Should not hack the transform from deviceOP + ComputeDataType reduce_length = type_convert( + x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = reduce::Add::template GetIdentityValue(); + mean_square_thread_buf(I) = reduce::Add::template GetIdentityValue(); + }); + + // Separate sweep once and sweep twice pipeline + if constexpr(SweepOnce) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf(i)); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + x_square_thread_buf(i)(Number{}) = + x_thread_buf(i)(Number{}) * + x_thread_buf(i)(Number{}); + }); + }); + + ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf); + ThreadwiseSumReduce::Reduce(x_square_thread_buf[i], mean_square_thread_buf); + + if constexpr(i != ThreadBufferNumber - 1) + { + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + } + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I)); + mean_thread_buf(I) = mean_thread_buf(I) / reduce_length; + + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I)); + mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length; + + // var(x) = E[x^2] - E[x]^2 + var_thread_buf(I) = + mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); + + inv_std_thread_buf(I) = type_convert(1.0f) / + ck::math::sqrt(var_thread_buf(I) + epsilon); + }); + + // save mean and inverse std for backward (optional) + if(thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + + // normalization + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + inv_std_thread_buf(iM); + + // gamma & beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, 
ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_beta_load.Run(beta_grid_desc_m_k, + beta_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + beta_thread_buf(i)); + + if constexpr(i != ThreadBufferNumber - 1) + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf(i), + y_grid_desc_m_k, + y_global_val_buf); + + if constexpr(i != ThreadBufferNumber - 1) + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + } // end of sweep once + else + { + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + x_square_thread_buf(i)(Number{}) = + x_thread_buf(i)(Number{}) * + x_thread_buf(i)(Number{}); + }); + }); + + ThreadwiseSumReduce::Reduce(x_thread_buf[i], mean_thread_buf); + ThreadwiseSumReduce::Reduce(x_square_thread_buf[i], mean_square_thread_buf); + }); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, mean_thread_buf(I)); + mean_thread_buf(I) = mean_thread_buf(I) / reduce_length; + + block_sync_lds(); + + BlockwiseSumReduce::Reduce(reduce_work_buf, mean_square_thread_buf(I)); + mean_square_thread_buf(I) = mean_square_thread_buf(I) / reduce_length; + + // var(x) = E[x^2] - E[x]^2 + var_thread_buf(I) = + mean_square_thread_buf(I) - (mean_thread_buf(I) * mean_thread_buf(I)); + + inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon); + }); + + if(thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + + auto thread_copy_tail_m_k = + (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k; + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + 
threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf(i)); + + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + inv_std_thread_buf(iM); + + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_beta_load.Run(beta_grid_desc_m_k, + beta_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + beta_thread_buf(i)); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf(i), + y_grid_desc_m_k, + y_global_val_buf); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + } + } // end of sweep twice + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp new file mode 100755 index 000000000000..8157b4fbc39c --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp @@ -0,0 +1,229 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
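+//
+// Illustrative usage sketch (assumption, not part of this file): the selector only returns a
+// pointer to one of four kernel instantiations (naive vs. Welford variance, generic vs.
+// sweep-once); the calling device op still configures and launches it. A hypothetical
+// host-side call could look like:
+//
+//   const auto kernel = NormalizationKernelSelector<true /*UseWelford*/>(is_sweep_once);
+//   kernel<<<dim3(grid_size), dim3(BlockSize), 0, stream>>>(
+//       x_grid_desc_m_k, gamma_grid_desc_m_k, beta_grid_desc_m_k, y_grid_desc_m_k,
+//       save_mean_grid_desc_m, save_inv_std_grid_desc_m, num_k_block_tile_iteration,
+//       epsilon, p_x, p_gamma, p_beta, p_y, p_save_mean, p_save_inv_std, y_elementwise_op);
+//
+// grid_size, is_sweep_once and the argument names above are placeholders chosen by the
+// calling device op, not values defined in this header.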
+ +#pragma once + +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_naive_variance.hpp" +#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp" + +namespace ck { +template +__global__ void +kernel_normalization(const GridDesc_M_K x_grid_desc_m_k, + const GridDesc_M_K gamma_grid_desc_m_k, + const GridDesc_M_K beta_grid_desc_m_k, + const GridDesc_M_K y_grid_desc_m_k, + const GridDesc_M save_mean_grid_desc_m, + const GridDesc_M save_inv_std_grid_desc_m, + index_t num_k_block_tile_iteration, + ComputeDataType epsilon, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, + const YElementwiseOperation y_elementwise_op) +{ + GridwiseReduction::Run(x_grid_desc_m_k, + gamma_grid_desc_m_k, + beta_grid_desc_m_k, + y_grid_desc_m_k, + save_mean_grid_desc_m, + save_inv_std_grid_desc_m, + num_k_block_tile_iteration, + epsilon, + p_x_global, + p_gamma_global, + p_beta_global, + p_y_global, + p_save_mean_global, + p_save_inv_std_global, + y_elementwise_op); +}; + +template +auto NormalizationKernelSelector(bool isSweepOnce) +{ + using GridwiseNormalizationGenericNaive = + GridwiseNormalizationNaiveVariance_mk_to_mk; + using GridwiseNormalizationSweepOnceNaive = + GridwiseNormalizationNaiveVariance_mk_to_mk; + using GridwiseNormalizationGenericWelford = + GridwiseNormalizationWelfordVariance_mk_to_mk; + using GridwiseNormalizationSweepOnceWelford = + GridwiseNormalizationWelfordVariance_mk_to_mk; + + if constexpr(UseWelford) + { + return isSweepOnce ? kernel_normalization + : kernel_normalization; + } + else + { + return isSweepOnce ? kernel_normalization + : kernel_normalization; + } +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp new file mode 100755 index 000000000000..80e9a84f9681 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp @@ -0,0 +1,256 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
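+//
+// Illustrative note (assumption about intent, not part of the kernel): this is the first pass
+// of the split-K normalization. Each k-block computes a partial Welford (mean, variance,
+// count) over its slice of the reduced dimension and stores it into an [M, KBlock]
+// workspace; the second-pass kernel merges those partials before normalizing. For
+// population-variance accumulators, merging two partials (a, b) follows the usual
+// parallel-Welford combination, roughly:
+//
+//   count = count_a + count_b;
+//   delta = mean_b - mean_a;
+//   mean  = mean_a + delta * count_b / count;
+//   var   = (var_a * count_a + var_b * count_b
+//            + delta * delta * count_a * count_b / count) / count;
+//
+// The exact bookkeeping used here lives in ThreadwiseWelfordMerge / BlockwiseWelford.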
+ +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +struct GridwiseNormalizationSplitK1st +{ + static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_K = Sequence; + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + using ThreadBufferLengths_M_1 = Sequence; + static constexpr auto thread_buffer_desc_m_1 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1)); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize; + + static constexpr auto ThreadBufferNumber = Number{}; + + __device__ static int + GetKPerThread(int k, int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id) + { + bool is_rightmost_block = block_k_cluster_id == kGridSize - 1; + + if(is_rightmost_block) + { + int left_kPerBlock = math::integer_divide_ceil(k, kGridSize); + int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1); + int kPerThread = kRightmostBlock < K_BlockTileSize + ? 
0 + : KThreadSliceSize * (kRightmostBlock / K_BlockTileSize); + int kPerBlockTail = kRightmostBlock - kPerThread * KThreadClusterSize; + + if(kPerBlockTail > 0) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + int thread_max_len = + (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i; + int delta = thread_max_len - kPerBlockTail; + delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize); + kPerThread += XSrcVectorSize - delta; + }); + } + + return kPerThread; + } + else + { + int kPerBlock = math::integer_divide_ceil(k, kGridSize); + return KThreadSliceSize * (kPerBlock / K_BlockTileSize); + } + } + + // Calculate mean and variance by welford along k dimension + __device__ static void Run(const XGridDesc_M_K& x_grid_desc_m_k, + const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock, + index_t num_k_block_tile_iteration, + const XDataType* const __restrict__ p_x_global, + MeanVarDataType* const p_mean_global, + MeanVarDataType* const p_variance_global, + int32_t* const p_welford_count_global) + { + auto x_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + StaticBuffer + mean_thread_buf; + StaticBuffer + var_thread_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const index_t k_grid_size = mean_var_grid_desc_m_kblock.GetLength(I1); + const index_t block_m_cluster_id = block_global_id / k_grid_size; + const index_t block_k_cluster_id = block_global_id % k_grid_size; + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration; + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index( + block_m_cluster_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_k_cluster_id * reduceSizePerBlock + thread_k_cluster_id * XSrcVectorSize)); + + auto mean_var_count_store_index = make_multi_index( + block_m_cluster_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, + block_k_cluster_id); + + auto threadwise_welford_mean_var_store = + ThreadwiseTensorSliceTransfer_v1r3, + 1, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>( + mean_var_grid_desc_m_kblock, mean_var_count_store_index, PassThroughOp{}); + + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + auto mean_global_val_buf = make_dynamic_buffer( + p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize()); + + auto var_global_val_buf = make_dynamic_buffer( + p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize()); + + auto threadwise_welford = ThreadwiseWelford(); + int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; + threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k.GetLength(I1), + kRaw, + k_grid_size, + block_k_cluster_id, + thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + for(index_t k = 0; k < num_k_block_tile_iteration; ++k) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, 
+ thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf); + }); + } + + int welford_count = 0; + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + + // The value of count is same for all I + if constexpr(I == MThreadSliceSize - 1) + welford_count = count; + }); + + if(thread_k_cluster_id == 0) + { + threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + mean_thread_buf, + mean_var_grid_desc_m_kblock, + mean_global_val_buf); + + threadwise_welford_mean_var_store.Run(thread_buffer_desc_m_1, + make_tuple(I0, I0), + var_thread_buf, + mean_var_grid_desc_m_kblock, + var_global_val_buf); + + if(block_m_cluster_id == 0 && thread_m_cluster_id == 0) + p_welford_count_global[block_k_cluster_id] = welford_count; + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp new file mode 100755 index 000000000000..9e380f963801 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_2nd.hpp @@ -0,0 +1,499 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/utility/math.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +struct GridwiseNormalizationSplitK2nd +{ + static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0, + "Invalid thread slice sizes and/or save mean and inverse std vector sizes " + "configuration, please check!"); + + static_assert(XSrcVectorSize == YDstVectorSize); + static_assert(XSrcVectorSize == GammaSrcVectorSize); + static_assert(XSrcVectorSize == BetaSrcVectorSize); + + static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_K = Sequence; + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + 
make_tuple(Number{}, Number{})); + + using ThreadBufferLengths_M = Sequence; + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using ThreadBufferLengths_M_1 = Sequence; + static constexpr auto thread_buffer_desc_m_1 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1)); + + using ThreadWelfordSrcDesc_M_1 = decltype(thread_buffer_desc_m_1); + using ThreadWelfordDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelfordMerge; + + using BlockwiseWelford = BlockwiseWelford; + + using PassThroughOp = tensor_operation::element_wise::PassThrough; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize; + + static constexpr auto ThreadBufferNumber = Number{}; + + __device__ static void Run(const MeanVarGridDesc_M_KBlock& mean_var_grid_desc_m_kblock, + const CountGridDesc_M_KBlock& count_grid_desc_m_kblock, + const XYGammaBetaGridDesc_M_K& x_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K& gamma_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K& beta_grid_desc_m_k, + const XYGammaBetaGridDesc_M_K& y_grid_desc_m_k, + const SaveMeanInvStdGridDesc_M& save_mean_grid_desc_m, + const SaveMeanInvStdGridDesc_M& save_inv_std_grid_desc_m, + index_t num_k_mean_var_count_iteration, + index_t num_k_block_tile_iteration, + index_t k_grid_size, + ComputeDataType epsilon, + const MeanVarDataType* const p_mean_global, + const MeanVarDataType* const p_variance_global, + const int32_t* const p_welford_count_global, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, + const YElementwiseOperation y_elementwise_op) + { + // Thread/Block id + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + const index_t block_m_cluster_id = block_global_id / k_grid_size; + const index_t block_k_cluster_id = block_global_id % k_grid_size; + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + // Global Memory + const auto mean_global_val_buf = make_dynamic_buffer( + p_mean_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize()); + + const auto var_global_val_buf = make_dynamic_buffer( + p_variance_global, mean_var_grid_desc_m_kblock.GetElementSpaceSize()); + + const auto welford_count_global_val_buf = make_dynamic_buffer( + p_welford_count_global, count_grid_desc_m_kblock.GetElementSpaceSize()); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto gamma_global_val_buf = make_dynamic_buffer( + p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize()); + + const auto beta_global_val_buf = make_dynamic_buffer( + p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); + + auto y_global_val_buf = make_dynamic_buffer( + p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + + auto save_mean_global_val_buf = 
make_dynamic_buffer( + p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize()); + + auto save_inv_std_global_val_buf = make_dynamic_buffer( + p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize()); + + // VGPR + StaticBuffer + in_mean_thread_buf; + StaticBuffer + in_var_thread_buf; + StaticBuffer + in_welford_count_thread_buf; + StaticBuffer + mean_thread_buf; + StaticBuffer + var_thread_buf; + StaticBuffer + welford_count_thread_buf; + auto& inv_std_thread_buf = var_thread_buf; + + auto x_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto gamma_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto& beta_thread_buf = gamma_thread_buf; + auto& y_thread_buf = x_thread_buf; + + // IO + auto threadwise_mean_var_load_m_kblock = + ThreadwiseTensorSliceTransfer_v2, + 1, + 1, + 1, + true>( + mean_var_grid_desc_m_kblock, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id)); + + auto threadwise_count_load_m_kblock = + ThreadwiseTensorSliceTransfer_v2, + 1, + 1, + 1, + true>( + count_grid_desc_m_kblock, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id)); + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration + + thread_k_cluster_id * XSrcVectorSize)); + + auto threadwise_gamma_load = + ThreadwiseTensorSliceTransfer_v2( + gamma_grid_desc_m_k, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration + + thread_k_cluster_id * GammaSrcVectorSize)); + + auto threadwise_beta_load = + ThreadwiseTensorSliceTransfer_v2( + beta_grid_desc_m_k, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration + + thread_k_cluster_id * BetaSrcVectorSize)); + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + block_k_cluster_id * K_BlockTileSize * num_k_block_tile_iteration + + thread_k_cluster_id * YDstVectorSize), + y_elementwise_op); + + auto threadwise_mean_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_mean_grid_desc_m, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_inv_std_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_inv_std_grid_desc_m, + make_multi_index(block_m_cluster_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + // step1: Merge mean and variance + constexpr auto mean_var_count_thread_copy_step_I0_k = + make_multi_index(I0, KThreadClusterSize); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + 
welford_count_thread_buf(I) = 0; + }); + + for(index_t k = 0; k < num_k_mean_var_count_iteration; ++k) + { + threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock, + mean_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_mean_thread_buf); + + threadwise_mean_var_load_m_kblock.Run(mean_var_grid_desc_m_kblock, + var_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_var_thread_buf); + + threadwise_count_load_m_kblock.Run(count_grid_desc_m_kblock, + welford_count_global_val_buf, + thread_buffer_desc_m_1, + make_tuple(I0, I0), + in_welford_count_thread_buf); + + ThreadwiseWelford::Run(in_mean_thread_buf, + in_var_thread_buf, + in_welford_count_thread_buf, + mean_thread_buf, + var_thread_buf, + welford_count_thread_buf); + + threadwise_mean_var_load_m_kblock.MoveSrcSliceWindow( + mean_var_grid_desc_m_kblock, mean_var_count_thread_copy_step_I0_k); + threadwise_count_load_m_kblock.MoveSrcSliceWindow(count_grid_desc_m_kblock, + mean_var_count_thread_copy_step_I0_k); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + BlockwiseWelford::Run( + mean_thread_buf(I), var_thread_buf(I), welford_count_thread_buf(I)); + + inv_std_thread_buf(I) = + type_convert(1.0f) / ck::math::sqrt(var_thread_buf(I) + epsilon); + }); + + // step2: save mean and inverse std for backward (optional) + if(block_k_cluster_id == 0 && thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + + // step3: normalization + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); + + for(index_t k = 0; k < num_k_block_tile_iteration; ++k) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf(i)); + + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + inv_std_thread_buf(iM); + + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_beta_load.Run(beta_grid_desc_m_k, + beta_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + beta_thread_buf(i)); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 
1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf(i), + y_grid_desc_m_k, + y_global_val_buf); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + } // end for (normalization) + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp new file mode 100755 index 000000000000..15b412fba4cc --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_welford_variance.hpp @@ -0,0 +1,564 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/data_type.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +// Y = Normalization(X, Beta, Gamma) +template +struct GridwiseNormalizationWelfordVariance_mk_to_mk +{ + static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || + (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) || + (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0), + "Invalid thread slice sizes and/or vector sizes configuration, please check!"); + + static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0, + "Invalid thread slice sizes and/or save mean and inverse std vector sizes " + "configuration, please check!"); + + static_assert(XSrcVectorSize == YDstVectorSize); + static_assert(XSrcVectorSize == GammaSrcVectorSize); + static_assert(XSrcVectorSize == BetaSrcVectorSize); + + static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0); + + using ThreadClusterLengths_M_K = Sequence; + + using ThreadBufferDimAccessOrder = + typename conditional, Sequence<0, 1>>::type; + + using ThreadClusterArrangeOrder = + typename conditional, Sequence<0, 1>>::type; + + static constexpr auto thread_cluster_desc = + make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); + + using ThreadBufferLengths_M_K = Sequence; + static constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + using ThreadBufferLengths_M = Sequence; + static constexpr auto thread_buffer_desc_m = + make_naive_tensor_descriptor_packed(make_tuple(Number{})); + + using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}))); + using ThreadReduceDstDesc_M = + decltype(make_naive_tensor_descriptor_packed(make_tuple(Number{}))); + + using ThreadwiseWelford = + ThreadwiseWelford; + + using BlockwiseWelford = BlockwiseWelford; + + using PassThroughOp = 
tensor_operation::element_wise::PassThrough; + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; + static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; + static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize; + + static constexpr auto ThreadBufferNumber = Number{}; + + __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k, + int thread_k_cluster_id) + { + // FIXME: Should not hack the transform from deviceOP + int kPerBlock = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]; + int kPerThread = + kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize); + int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize; + + if(kPerBlockTail > 0) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + int thread_max_len = + (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i; + int delta = thread_max_len - kPerBlockTail; + delta = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize); + kPerThread += XSrcVectorSize - delta; + }); + } + + return kPerThread; + } + + __device__ static void Run(const GridDesc_M_K& x_grid_desc_m_k, + const GridDesc_M_K& gamma_grid_desc_m_k, + const GridDesc_M_K& beta_grid_desc_m_k, + const GridDesc_M_K& y_grid_desc_m_k, + const GridDesc_M& save_mean_grid_desc_m, + const GridDesc_M& save_inv_std_grid_desc_m, + index_t num_k_block_tile_iteration, + ComputeDataType epsilon, + const XDataType* const __restrict__ p_x_global, + const GammaDataType* const __restrict__ p_gamma_global, + const BetaDataType* const __restrict__ p_beta_global, + YDataType* const __restrict__ p_y_global, + SaveMeanInvStdDataType* const __restrict__ p_save_mean_global, + SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global, + const YElementwiseOperation y_elementwise_op) + { + auto x_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto gamma_thread_buf = generate_tuple( + [&](auto) { + return StaticBuffer{}; + }, + Number{}); + + auto& beta_thread_buf = gamma_thread_buf; + auto& y_thread_buf = x_thread_buf; + + StaticBuffer + mean_thread_buf; + StaticBuffer + var_thread_buf; + auto& inv_std_thread_buf = var_thread_buf; + + const index_t thread_local_id = get_thread_local_1d_id(); + const index_t block_global_id = get_block_1d_id(); + + const auto thread_cluster_idx = + thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id)); + + const auto thread_m_cluster_id = thread_cluster_idx[I0]; + const auto thread_k_cluster_id = thread_cluster_idx[I1]; + + auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2( + x_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * XSrcVectorSize)); + + auto threadwise_gamma_load = + ThreadwiseTensorSliceTransfer_v2( + gamma_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * GammaSrcVectorSize)); + + auto threadwise_beta_load = + ThreadwiseTensorSliceTransfer_v2( + beta_grid_desc_m_k, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * BetaSrcVectorSize)); + + auto threadwise_y_store = + ThreadwiseTensorSliceTransfer_v1r3( + y_grid_desc_m_k, + 
make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize, + thread_k_cluster_id * YDstVectorSize), + y_elementwise_op); + + auto threadwise_mean_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_mean_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + auto threadwise_inv_std_store = + ThreadwiseTensorSliceTransfer_v1r3, // DimAccessOrder + 0, // SrcVectorDim + SaveMeanInvStdDstVectorSize, // ScalarPerVector + InMemoryDataOperationEnum::Set, + 1, + true>( + save_inv_std_grid_desc_m, + make_multi_index(block_global_id * M_BlockTileSize + + thread_m_cluster_id * MThreadSliceSize), + PassThroughOp{}); + + constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize); + constexpr auto thread_copy_bwd_step_m_k = + make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize); + + const auto x_global_val_buf = make_dynamic_buffer( + p_x_global, x_grid_desc_m_k.GetElementSpaceSize()); + + const auto gamma_global_val_buf = make_dynamic_buffer( + p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize()); + + const auto beta_global_val_buf = make_dynamic_buffer( + p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize()); + + auto y_global_val_buf = make_dynamic_buffer( + p_y_global, y_grid_desc_m_k.GetElementSpaceSize()); + + auto save_mean_global_val_buf = make_dynamic_buffer( + p_save_mean_global, save_mean_grid_desc_m.GetElementSpaceSize()); + + auto save_inv_std_global_val_buf = make_dynamic_buffer( + p_save_inv_std_global, save_inv_std_grid_desc_m.GetElementSpaceSize()); + + auto threadwise_welford = ThreadwiseWelford(); + threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + mean_thread_buf(I) = type_convert(0.0f); + var_thread_buf(I) = type_convert(0.0f); + }); + + // Separate sweep once and sweep twice pipeline + if constexpr(SweepOnce) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf(i)); + + threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf); + + if constexpr(i != ThreadBufferNumber - 1) + { + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + } + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + inv_std_thread_buf(I) = type_convert(1.0f) / + ck::math::sqrt(var_thread_buf(I) + epsilon); + }); + + // save mean and inverse std for backward (optional) + if(thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + 
save_inv_std_global_val_buf); + } + } + + // normalization + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + inv_std_thread_buf(iM); + + // gamma & beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_beta_load.Run(beta_grid_desc_m_k, + beta_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + beta_thread_buf(i)); + + if constexpr(i != ThreadBufferNumber - 1) + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf(i), + y_grid_desc_m_k, + y_global_val_buf); + + if constexpr(i != ThreadBufferNumber - 1) + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + } // end of sweep once + else + { + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + threadwise_welford.Run(x_thread_buf[i], mean_thread_buf, var_thread_buf); + }); + } + + static_for<0, MThreadSliceSize, 1>{}([&](auto I) { + if constexpr(I > 0) + block_sync_lds(); + + int count = threadwise_welford.cur_count_; + BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count); + inv_std_thread_buf(I) = 1 / ck::math::sqrt(var_thread_buf(I) + epsilon); + }); + + if(thread_k_cluster_id == 0) + { + if(p_save_mean_global != nullptr) + { + threadwise_mean_store.Run(thread_buffer_desc_m, + make_tuple(I0), + mean_thread_buf, + save_mean_grid_desc_m, + save_mean_global_val_buf); + } + if(p_save_inv_std_global != nullptr) + { + threadwise_inv_std_store.Run(thread_buffer_desc_m, + make_tuple(I0), + inv_std_thread_buf, + save_inv_std_grid_desc_m, + save_inv_std_global_val_buf); + } + } + + auto thread_copy_tail_m_k = + (num_k_block_tile_iteration - 1) * ThreadBufferNumber * thread_copy_fwd_step_m_k; + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k); + + for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles) + { + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_x_load.Run(x_grid_desc_m_k, + x_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, 
I0), + x_thread_buf(i)); + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_gamma_load.Run(gamma_grid_desc_m_k, + gamma_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + gamma_thread_buf(i)); + + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // normalize + y_thread_buf(iK0)(Number{}) = + (x_thread_buf(iK0)(Number{}) - mean_thread_buf(iM)) * + inv_std_thread_buf(iM); + + // gamma + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) * + gamma_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_beta_load.Run(beta_grid_desc_m_k, + beta_global_val_buf, + thread_buffer_desc_m_k, + make_tuple(I0, I0), + beta_thread_buf(i)); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { + static_for<0, ThreadBufferNumber, 1>{}([&](auto iK0) { + static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) { + constexpr auto offset_m_k = + thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1)); + + // beta + y_thread_buf(iK0)(Number{}) = + y_thread_buf(iK0)(Number{}) + + beta_thread_buf(iK0)(Number{}); + }); + }); + }); + + static_for<0, ThreadBufferNumber, 1>{}([&](auto i) { + threadwise_y_store.Run(thread_buffer_desc_m_k, + make_tuple(I0, I0), + y_thread_buf(i), + y_grid_desc_m_k, + y_global_val_buf); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, + thread_copy_fwd_step_m_k); + }); + + threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k); + threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, + 2 * thread_copy_bwd_step_m_k); + } + } // end of sweep twice + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp new file mode 100755 index 000000000000..c6eecc067d9f --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/reduction_functions_threadwise.hpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/utility/reduction_functions_accumulate.hpp" + +namespace ck { + +// Assume +// 1) SrcDesc is known at compile-time +// 2) DstDesc is known at compile-time +// 3) SrcBuffer is static buffer +// 4) DstBuffer is static buffer +template > +struct ThreadwiseReduction +{ + static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{}; + static constexpr auto dst_thread_desc_m = DstThreadDesc_M{}; + + static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{}); + static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{}); + static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{}); + + static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!"); + + using Op = OpReduce; + + template + __device__ static void Reduce(const SrcBufferType& src_buf, DstBufferType& dst_buf) + { + static_for<0, src_length_m, 1>{}([&](auto iM) { + constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM)); + + static_for<0, src_length_k, 1>{}([&](auto iK) { + constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + Accumulation::Calculate(dst_buf(Number{}), src_buf[Number{}]); + }); + }); + }; +}; + +// Assume +// 1) SrcDesc is known at compile-time +// 2) DstDesc is known at compile-time +// 3) SrcBuffer is static buffer +// 4) DstBuffer is static buffer +template < + typename AccDataType, + typename IndexDataType, + typename SrcThreadDesc_M_K, + typename DstThreadDesc_M, + typename OpReduce, + bool PropagateNan, + typename Accumulation = + detail::AccumulateWithIndexAndNanCheck> +struct ThreadwiseReductionWithIndex +{ + static constexpr auto src_thread_desc_m_k = SrcThreadDesc_M_K{}; + static constexpr auto dst_thread_desc_m = DstThreadDesc_M{}; + + static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{}); + static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{}); + static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{}); + + static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!"); + + template + __device__ static void Reduce(const SrcValueBufferType& src_val_buf, + const SrcIndexBufferType& src_idx_buf, + DstValueBufferType& dst_val_buf, + DstIndexBufferType& dst_idx_buf) + { + static_for<0, src_length_m, 1>{}([&](auto iM) { + constexpr index_t out_offset = dst_thread_desc_m.CalculateOffset(make_tuple(iM)); + + static_for<0, src_length_k, 1>{}([&](auto iK) { + constexpr auto offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + Accumulation::Calculate(dst_val_buf(Number{}), + src_val_buf[Number{}], + dst_idx_buf(Number{}), + src_idx_buf[Number{}]); + }); + }); + }; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp new file mode 100755 index 000000000000..44730d551c5c --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp @@ -0,0 +1,224 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +// C[TM0, TM1, TN0, TN1] += A[TK, TM0, TM1] * B[TK, TN0, TN1] +// Tensor element can be vectorized data +// Assume: +// 1. 
AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are +// known at compile-time +// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time +template ::type = false> +struct ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1 +{ + __device__ constexpr ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1() + { + static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && + BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && + CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + // TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, + // CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths + + // TODO remove this restriction + static_assert(TKLengths::Size() == 1 && TMLengths::Size() == 2 && TNLengths::Size() == 2, + "wrong!"); + } + + template + __device__ static void Run(const ABuffer& a_buf, + AOriginIdx, + const BBuffer& b_buf, + BOriginIdx, + CBuffer& c_buf, + COriginIdx) + { + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr auto TK = TKLengths{}[I0]; + constexpr auto TM0 = TMLengths{}[I0]; + constexpr auto TM1 = TMLengths{}[I1]; + constexpr auto TN0 = TNLengths{}[I0]; + constexpr auto TN1 = TNLengths{}[I1]; + + constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); + constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); + constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); + + static_for<0, TK, 1>{}([&](auto tk) { + static_for<0, TM0, 1>{}([&](auto tm0) { + static_for<0, TM1, 1>{}([&](auto tm1) { + static_for<0, TN0, 1>{}([&](auto tn0) { + static_for<0, TN1, 1>{}([&](auto tn1) { + constexpr index_t a_offset = + AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( + a_origin_idx + make_multi_index(tk, tm0, tm1)); + constexpr index_t b_offset = + BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( + b_origin_idx + make_multi_index(tk, tn0, tn1)); + constexpr index_t c_offset = + CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( + c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); + + inner_product(a_buf[Number{}], + b_buf[Number{}], + c_buf(Number{})); + }); + }); + }); + }); + }); + } +}; + +// C[TM0, TM1, TN0, TN1] += A[TK0, TM0, TM1, TK1] * B[TK0, TN0, TN1, TK1] +// Tensor element can be vectorized data +// Assume: +// 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are +// known at compile-time +// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time +template ::type = false> +struct ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 +{ + __device__ constexpr ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() + { + static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && + BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && + CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), + "wrong! 
Desc should be known at compile-time"); + + // TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, + // CThreadDesc_TM0_TM1_TN0_TN1 Size with KLenghts, TMLengths and TNLengths + + // TODO remove this restriction + static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2, + "wrong!"); + } + + template + __device__ static void Run(const ABuffer& a_buf, + AOriginIdx, + const BBuffer& b_buf, + BOriginIdx, + CBuffer& c_buf, + COriginIdx) + { + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + constexpr index_t TK0 = TKLengths{}[I0]; + constexpr index_t TK1 = TKLengths{}[I1]; + constexpr index_t TM0 = TMLengths{}[I0]; + constexpr index_t TM1 = TMLengths{}[I1]; + constexpr index_t TN0 = TNLengths{}[I0]; + constexpr index_t TN1 = TNLengths{}[I1]; + + constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); + constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); + constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); + + static_for<0, TK0, 1>{}([&](auto tk0) { + static_for<0, TM0, 1>{}([&](auto tm0) { + static_for<0, TM1, 1>{}([&](auto tm1) { + static_for<0, TN0, 1>{}([&](auto tn0) { + static_for<0, TN1, 1>{}([&](auto tn1) { + vector_type a_vec; + vector_type b_vec; + + static_for<0, TK1, 1>{}([&](auto tk1) { + constexpr index_t a_offset = + AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset( + a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1)); + + constexpr index_t b_offset = + BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset( + b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1)); + + a_vec.template AsType()(tk1) = a_buf[Number{}]; + b_vec.template AsType()(tk1) = b_buf[Number{}]; + }); + + using a_vector_t = typename vector_type::type; + using b_vector_t = typename vector_type::type; + + constexpr index_t c_offset = + CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset( + c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1)); + + inner_product( + a_vec.template AsType()[I0], + b_vec.template AsType()[I0], + c_buf(Number{})); + }); + }); + }); + }); + }); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp new file mode 100755 index 000000000000..e97aa433a6f3 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_gemm_dlops_v3.hpp @@ -0,0 +1,168 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP +#define CK_THREADWISE_GEMM_DLOPS_V3_HPP + +#include "common_header.hpp" +#include "math.hpp" + +namespace ck { + +// C[M, N] += transpose(A[K, M]) * B[K, N] +// Element of matrix can be vectorized data +// Assume: +// 1. AThreadDesc_E1_K_E2, BThreadDesc_E1_N_Ho_Wo_E2, CThreadDesc_K_N_Ho_Wo are known at +// compile-time +// 2. 
AOriginIdx, BOriginIdx, COriginIdx are known at compile-time +template ::type = false> +struct ThreadwiseGemmDlops_km_kn_mn_v3 +{ + + template + __device__ static void Run(const ABuffer& a_buf, + AOriginIdx, + const BBuffer& b_buf, + BOriginIdx, + CBuffer& c_buf, + COriginIdx) + { + + static_assert(AThreadDesc_E1_K_E2::IsKnownAtCompileTime() && + BThreadDesc_E1_N_Ho_Wo_E2::IsKnownAtCompileTime() && + CThreadDesc_K_N_Ho_Wo::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto E1 = AThreadDesc_E1_K_E2{}.GetLength(I0); + constexpr auto K = AThreadDesc_E1_K_E2{}.GetLength(I1); + constexpr auto E2 = AThreadDesc_E1_K_E2{}.GetLength(I2); + + constexpr auto Ho = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I2); + constexpr auto Wo = BThreadDesc_E1_N_Ho_Wo_E2{}.GetLength(I3); + + constexpr auto a_origin_idx = to_multi_index(AOriginIdx{}); + constexpr auto b_origin_idx = to_multi_index(BOriginIdx{}); + constexpr auto c_origin_idx = to_multi_index(COriginIdx{}); + + if constexpr((Ho % 2 == 0) && (Wo % 2 == 0)) + { + constexpr auto SubHW = 2; + + static_for<0, K, 1>{}([&](auto k) { + static_for<0, Ho, SubHW>{}([&](auto h) { + static_for<0, Wo, SubHW>{}([&](auto w) { + static_for<0, E1, 1>{}([&](auto e1) { + static_for<0, E2, 1>{}([&](auto e2) { + constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset( + a_origin_idx + make_tuple(e1, k, e2)); + + constexpr index_t b0_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h, w, e2)); + + constexpr index_t b1_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h, w + 1, e2)); + + constexpr index_t b2_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h + 1, w, e2)); + + constexpr index_t b3_offset = + BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h + 1, w + 1, e2)); + + constexpr index_t c0_offset = + CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + + make_tuple(k, 0, h, w)); + + constexpr index_t c1_offset = + CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( + c_origin_idx + make_tuple(k, 0, h, w + 1)); + + constexpr index_t c2_offset = + CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( + c_origin_idx + make_tuple(k, 0, h + 1, w)); + + constexpr index_t c3_offset = + CThreadDesc_K_N_Ho_Wo{}.CalculateOffset( + c_origin_idx + make_tuple(k, 0, h + 1, w + 1)); + + amd_assembly_outer_product_1x4(a_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + b_buf[Number{}], + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{}), + c_buf(Number{})); + }); + }); + }); + }); + }); + } + else + { + + static_for<0, K, 1>{}([&](auto k) { + static_for<0, Ho, 1>{}([&](auto h) { + static_for<0, Wo, 1>{}([&](auto w) { + static_for<0, E1, 1>{}([&](auto e1) { + static_for<0, E2, 1>{}([&](auto e2) { + constexpr index_t a_offset = AThreadDesc_E1_K_E2{}.CalculateOffset( + a_origin_idx + make_tuple(e1, k, e2)); + + constexpr index_t b_offset = + 
BThreadDesc_E1_N_Ho_Wo_E2{}.CalculateOffset( + b_origin_idx + make_tuple(e1, 0, h, w, e2)); + + constexpr index_t c_offset = + CThreadDesc_K_N_Ho_Wo{}.CalculateOffset(c_origin_idx + + make_tuple(k, 0, h, w)); + + inner_product(a_buf[Number{}], + b_buf[Number{}], + c_buf(Number{})); + }); + }); + }); + }); + }); + } + } +}; + +} // namespace ck +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp new file mode 100755 index 000000000000..6774a35bcb87 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +namespace ck { + +// Assume: +// 1. Desc is known at compile-time +// 2. Buffer is StaticBuffer +// 3. OriginIdx is known at compile-time +// 4. use #-step +template ::type = false> +struct ThreadwiseTensorSliceSet_v1 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + template + __device__ void Run(const Desc&, const OriginIdx&, Buffer& buf, const Data& initial_value) const + { + static_assert(Desc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert(is_known_at_compile_time>::value, + "wrong! OriginIdx need to be known at compile-time"); + + // Desc is known at compile-time + constexpr auto desc = remove_cvref_t{}; + + // OriginIdx is known at compile-time + constexpr auto origin_idx = to_multi_index(OriginIdx{}); + + static_ford{}([&](auto access_idx) { + constexpr auto coord = make_tensor_coordinate(desc, origin_idx + access_idx); + + constexpr bool is_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord); + + constexpr index_t offset = coord.GetOffset(); + + if constexpr(is_valid) + { + buf(Number{}) = initial_value; + } + }); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp new file mode 100755 index 000000000000..4e4c92de407a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -0,0 +1,2102 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + +namespace ck { +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is StaticBuffer +// 3. SrcSliceOrginIdx is known at compile-time +// 2. dst: +// 1. DstDesc is not known at compile-time +// 2. DstBuffer is DynamicBuffer +// 3. 
DstSliceOrginIdx is not known at compile time +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v1r3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc, + const Index& dst_slice_origin_idx, + const ElementwiseOperation& element_op) + : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)), + element_op_{element_op} + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + static_assert(SrcDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value, + "wrong! SrcSliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + // TODO: Use SpaceFillingCurve::ScalarsPerAccess instread of DstScalarPerVector? + static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + typename vector_type_maker::type dst_vector; + using dst_vector_t = typename vector_type_maker::type::type; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + + // copy data from src_buf into dst_vector + // TODO: It's a hack here to use \p dst_scalar_step_in_vector. Use SpaceFillingCurve? 
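+            // The static_for below reads DstScalarPerVector scalars from the static source
+            // buffer, applies element_op_ to each one, and packs the results into dst_vector,
+            // so that the Update() call that follows can issue a single vectorized store.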
+ static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + DstData v; + + // apply element-wise operation + element_op_(v, src_buf[Number{}]); + + dst_vector.template AsType()(i) = v; + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; // namespace ThreadwiseTensorSliceTransfer_v1r3 + +/** + * @brief Helper structure that facilitates transfer of source (grid) data to destination threads. + * + * @details The following assumptions are made: + * - For Source (Grid) Data: + * 1. The source tensor descriptor SrcDesc is not known at compile-time. + * 2. The source buffer is a dynamic buffer. + * 3. The source slice origin index src_slice_origin_idx is not known at compile-time. + * - For Destination (Thread) Data: + * 1. The destination tensor descriptor DstDesc is known at compile-time. + * 2. The destination buffer dst_buf is a static buffer. + * 3. The destination slice origin index dst_slice_origin_idx is known at compile-time. + * + * @tparam SrcData The data type of the source tensor. + * @tparam DstData The data type of the destination tensor. + * @tparam SrcDesc The descriptor type of the source tensor. + * @tparam DstDesc The descriptor type of the destination tensor. + * @tparam SliceLengths The lengths of the slice to be transferred. + * @tparam DimAccessOrder The order of dimension access for the space-filling curve. 
+ * @tparam SrcVectorDim The dimension along which vectorized access is performed in the source + * tensor. + * @tparam SrcScalarPerVector The number of scalar elements per vector in the source tensor. + * @tparam SrcScalarStrideInVector Not used. + * @tparam SrcResetCoordinateAfterRun controls whether source coordinate is restored after each Run + * or rolled back one step in MoveSrcSliceWindow + * @tparam InvalidElementAsNaN Whether to fill invalid elements with NaN (only applicable for + * floating-point types). + * + */ +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v2 +{ + static_assert((InvalidElementAsNaN && !ck::is_integral::value) || + (!InvalidElementAsNaN), + "Filling invalid element as NaN is only for floating point types"); + + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + + static constexpr index_t PackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc, + const Index& src_slice_origin_idx) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)) + { + static_assert(DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! Not divisible"); + + if constexpr(is_same_v, pk_i4_t> || + is_same_v, f4x2_pk_t>) + { + static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1"); + } + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) + { + static_assert(DstDesc::IsKnownAtCompileTime(), + "wrong! DstDesc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value, + "wrong! DstSliceOrigin need to known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + "wrong! 
inconsistent type"); + + // DstDesc and dst_slice_origin_idx are known at compile-time + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + // loop over tensor and copy + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + typename vector_type_maker::type src_vector; + + using src_vector_t = + typename vector_type_maker::type::type; + constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf into src_vector + src_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_coord_.GetOffset() / PackedSize, + is_src_valid); + + // copy data from src_vector into dst_buf + static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) { + constexpr index_t dst_offset = + dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx + + i * src_scalar_step_in_vector); + + if constexpr(InvalidElementAsNaN) + { + dst_buf(Number{}) = + is_src_valid + ? type_convert(src_vector.template AsType()[i]) + : NumericLimits::QuietNaN(); + } + else + { + dst_buf(Number{}) = + type_convert(src_vector.template AsType()[i]); + } + }); + + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + + move_tensor_coordinate( + src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); + } + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? 
+ const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx, + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step( + src_desc, adjusted_step_idx, src_move_slice_window_step_hack); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + private: + SrcCoord src_coord_; +}; // namespace ck + +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v2_gather +{ + static_assert((InvalidElementAsNaN && !ck::is_integral::value) || + (!InvalidElementAsNaN), + "Filling invalid element as NaN is only for floating point types"); + + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + + static constexpr index_t PackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v2_gather( + const SrcDesc& src_desc, + const Index& src_slice_origin_idx, + const StaticallyIndexedArray& scale_gather_offsets) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)), + scale_gather_offsets_(scale_gather_offsets) + { + static_assert(DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc need to known at compile-time"); + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! Not divisible"); + + if constexpr(is_same_v, pk_i4_t>) + { + static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1"); + } + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + auto adjusted_origin_idx = [&]() { + Index idx; + + static_for<0, nDim, 1>{}( + [&](auto i) { idx(i) = i.value == 0 ? 0 : src_slice_origin_idx[Number{}]; }); + + return idx; + }(); + + src_coord_ = make_tensor_coordinate(src_desc, adjusted_origin_idx); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) + { + static_assert(DstDesc::IsKnownAtCompileTime(), + "wrong! DstDesc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value, + "wrong! DstSliceOrigin need to known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + "wrong! 
inconsistent type"); + + // DstDesc and dst_slice_origin_idx are known at compile-time + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + // loop over tensor and copy + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, scale_gather_num, 1>{}([&](auto gather_idx) { + constexpr auto current_dst_origin = + to_multi_index(dst_slice_origin_idx) + make_multi_index(gather_idx, 0); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + typename vector_type_maker::type + src_vector; + + using src_vector_t = + typename vector_type_maker::type::type; + constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, + src_coord_); + + // copy data from src_buf into src_vector + src_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_coord_.GetOffset() / PackedSize + + scale_gather_offsets_(gather_idx), + is_src_valid); + + // copy data from src_vector into dst_buf + static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) { + constexpr index_t dst_offset = + dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + + src_data_idx + i * src_scalar_step_in_vector); + constexpr auto full_dst_offset = + dst_desc.CalculateOffset(current_dst_origin) + dst_offset; + + if constexpr(InvalidElementAsNaN) + { + dst_buf(full_dst_offset) = + is_src_valid + ? type_convert(src_vector.template AsType()[i]) + : NumericLimits::QuietNaN(); + } + else + { + dst_buf(Number{}) = + type_convert(src_vector.template AsType()[i]); + } + }); + + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + + move_tensor_coordinate( + src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? 
+ const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx, + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step( + src_desc, adjusted_step_idx, src_move_slice_window_step_hack); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + private: + SrcCoord src_coord_; + StaticallyIndexedArray scale_gather_offsets_; +}; // namespace ck + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +template // control whether to move back dst coordinate after each + // RunWrite(), will be fused with MoveDstSliceWindow to + // save addr computation +struct ThreadwiseTensorSliceTransfer_v3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3(const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)) + { + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! Not divisible"); + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void + RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! 
SrcBuffer and SrcData data type are inconsistent"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + src_desc, forward_step_idx, src_step_hacks[I0][i]); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + src_desc, backward_step_idx, src_step_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + vector_type_maker_t src_tmp_vector; + + using src_vector_t = typename decltype(src_tmp_vector)::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf to src_tmp_vector + src_tmp_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_coord_.GetOffset(), is_src_valid); + + // copy data from src_tmp_vector to buffer_ + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector); + + buffer_(Number{}) = src_tmp_vector.template AsType()[i]; + }); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + template + __device__ void + RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks) + { + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + dst_desc, forward_step_idx, dst_step_hacks[I0][i]); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? 
-dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step( + dst_desc, backward_step_idx, dst_step_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + vector_type_maker_t dst_tmp_vector; + + // copy data from buffer_ to dst_tmp_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector); + + dst_tmp_vector.template AsType()(i) = + type_convert(buffer_[Number{}]); + }); + + using dst_vector_t = typename decltype(dst_tmp_vector)::type; + + // copy data from dst_tmp_vector to dst_buf + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_tmp_vector.template AsType()[Number<0>{}]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto src_step_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunRead(src_desc, src_buf, src_step_hacks); + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) + { + constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto dst_step_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { 
return zeros; }, Number{})); + + RunWrite(dst_desc, dst_buf, dst_step_hacks); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto I0 = Number<0>{}; + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx, + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step( + src_desc, adjusted_step_idx, src_move_slice_window_step_hack); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + static constexpr auto buffer_desc_ = + make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{})); + + static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); + + StaticBuffer buffer_; + + SrcCoord src_coord_; + DstCoord dst_coord_; +}; + +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is DynamicBuffer +// 3. src_ref_idx is known at run-time +// 4. SrcRefToOriginDisplacement is known at compile-time +// 5. use #-step +// 2. dst: +// 1. DstDesc is known at compile-time +// 2. DstBuffer is StaticBuffer +// 3. DstOriginIdx is known at compile-time +// 4. use direct address calculation +// 3. 
vector access on src +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v4 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + + static constexpr index_t PackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx) + : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + if constexpr(is_same_v, pk_i4_t> || + is_same_v, f4x2_pk_t>) + { + static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1"); + } + } + + template + __device__ void Run(const SrcDesc&, + const SrcRefToOriginDisplacement&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); + + // SrcDesc and DstDesc are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + + // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time + constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); + constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); + + // scalar per access of each dim + constexpr auto src_scalar_per_access = generate_sequence_v2( + [&](auto i) constexpr { + if constexpr(i == SrcVectorDim) + { + return Number{}; + } + else + { + return Number<1>{}; + } + }, + Number{}); + + // scalar step (if steping on SrcVectorDim) of each dim + constexpr auto src_scalar_step_in_vector = generate_sequence_v2( + [&](auto i) constexpr { + if constexpr(i == SrcVectorDim) + { + return Number<1>{}; + } + else + { + return Number<0>{}; + } + }, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static_ford{}([&](auto ordered_access_idx) { +#if 0 + // TODO: unable to compile + // position in slice window + constexpr auto data_to_origin_disp_idx = + container_reorder_given_old2new(ordered_access_idx, dim_access_order) * + src_scalar_per_access; +#else + // position in slice window + constexpr auto data_to_origin_disp_idx = + ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access; +#endif + // src coordinate + constexpr auto src_ref_to_data_disp_idx = + src_ref_to_origin_disp_idx + data_to_origin_disp_idx; + + constexpr auto src_ref_to_data_disp_coord_step = + make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); + + auto src_data_coord = src_ref_coord_; + + 
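            // [editor's note] Descriptive comment: each static_ford iteration gathers one
            // SrcScalarPerVector-wide vector. The run-time reference coordinate is copied and
            // advanced by a displacement that is fully known at compile time, while the
            // destination offset into the static buffer is computed directly from DstOriginIdx.
            // Illustrative shape (assumed, not from this file): with SliceLengths = Sequence<2, 4>,
            // SrcVectorDim = 1, SrcScalarPerVector = 4 and DimAccessOrder = Sequence<0, 1>,
            // access_lengths is <2, 1>, i.e. two 4-wide gathers at window offsets (0, 0) and (1, 0).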
move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); + + vector_type_maker_t src_tmp_vector; + + using src_vector_t = typename decltype(src_tmp_vector)::type; + + const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_desc, src_data_coord); + + // copy data from src_buf into src_tmp_vector + if constexpr(SrcBuffer::IsDynamicBuffer()) + { + src_tmp_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_data_coord.GetOffset() / PackedSize, + is_src_valid); + } + else if constexpr(SrcBuffer::IsStaticBuffer()) + { + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t src_offset = src_desc.CalculateOffset( + src_ref_to_origin_disp_idx + data_to_origin_disp_idx + + i * src_scalar_step_in_vector); + + src_tmp_vector.template AsType()(i) = src_buf[Number{}]; + }); + } + + if constexpr(is_same, pk_i4_t>::value) + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + constexpr index_t pack_size = 8; + + static_assert(SrcScalarPerVector % pack_size == 0, ""); + + using src_v_t = typename vector_type_maker_t::type; + using dst_v_t = typename vector_type_maker_t::type; + + static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { + ck::tensor_operation::element_wise::PassThroughPack8{}( + dst_tmp_vector.template AsType()(i), + src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + else if constexpr(is_same, f8_t>::value && + is_same, half_t>::value && + SrcScalarPerVector % 2 == 0) + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + constexpr index_t pack_size = 2; + + using dst_v_t = typename vector_type_maker_t::type; + using src_v_t = typename vector_type_maker_t::type; + static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { + ck::tensor_operation::element_wise::PassThroughPack2{}( + dst_tmp_vector.template AsType()(i), + src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + else + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + // TODO: if SrcData and DstData are vetor type, then static_cast may not compile + static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) { + dst_tmp_vector.template AsType()(i) = + type_convert(src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + }); + } + + // Fuse scale + template + __device__ void Run(const SrcDesc&, + const SrcRefToOriginDisplacement&, + const 
SrcBuffer& src_buf, + const DstData& scale, + const DstDesc&, + const DstOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); + + // SrcDesc and DstDesc are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + + // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time + constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); + constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); + + // scalar per access of each dim + constexpr auto src_scalar_per_access = generate_sequence_v2( + [&](auto i) constexpr { + if constexpr(i == SrcVectorDim) + { + return Number{}; + } + else + { + return Number<1>{}; + } + }, + Number{}); + + // scalar step (if steping on SrcVectorDim) of each dim + constexpr auto src_scalar_step_in_vector = generate_sequence_v2( + [&](auto i) constexpr { + if constexpr(i == SrcVectorDim) + { + return Number<1>{}; + } + else + { + return Number<0>{}; + } + }, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static_ford{}([&](auto ordered_access_idx) { +#if 0 + // TODO: unable to compile + // position in slice window + constexpr auto data_to_origin_disp_idx = + container_reorder_given_old2new(ordered_access_idx, dim_access_order) * + src_scalar_per_access; +#else + // position in slice window + constexpr auto data_to_origin_disp_idx = + ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access; +#endif + // src coordinate + constexpr auto src_ref_to_data_disp_idx = + src_ref_to_origin_disp_idx + data_to_origin_disp_idx; + + constexpr auto src_ref_to_data_disp_coord_step = + make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); + + auto src_data_coord = src_ref_coord_; + + move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); + + vector_type_maker_t src_tmp_vector; + + using src_vector_t = typename decltype(src_tmp_vector)::type; + + const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_desc, src_data_coord); + + // copy data from src_buf into src_tmp_vector + if constexpr(SrcBuffer::IsDynamicBuffer()) + { + src_tmp_vector.template AsType()(Number<0>{}) = + src_buf.template Get(src_data_coord.GetOffset() / PackedSize, + is_src_valid); + } + else if constexpr(SrcBuffer::IsStaticBuffer()) + { + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t src_offset = src_desc.CalculateOffset( + src_ref_to_origin_disp_idx + data_to_origin_disp_idx + + i * src_scalar_step_in_vector); + + src_tmp_vector.template AsType()(i) = src_buf[Number{}]; + }); + } + + if constexpr(is_same, pk_i4_t>::value) + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData 
to + // DstData) + vector_type_maker_t dst_tmp_vector; + vector_type scale_vector; + scale_vector.template AsType()(Number<0>{}) = scale; + scale_vector.template AsType()(Number<1>{}) = scale; + + constexpr index_t pack_size = 8; + + static_assert(SrcScalarPerVector % pack_size == 0, ""); + + using src_v_t = typename vector_type_maker_t::type; + using dst_v_t = typename vector_type_maker_t::type; + using scale_v_t = typename vector_type_maker_t::type; + + static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { + ck::tensor_operation::element_wise::DequantPack8{}( + dst_tmp_vector.template AsType()(i), + src_tmp_vector.template AsType()[i], + scale_vector.template AsType()[Number<0>{}]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + else if constexpr(is_same, f8_t>::value && + is_same, half_t>::value && + SrcScalarPerVector % 2 == 0) + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + constexpr index_t pack_size = 2; + + using dst_v_t = typename vector_type_maker_t::type; + using src_v_t = typename vector_type_maker_t::type; + static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) { + ck::tensor_operation::element_wise::PassThroughPack2{}( + dst_tmp_vector.template AsType()(i), + src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + else + { + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + // TODO: if SrcData and DstData are vetor type, then static_cast may not compile + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + dst_tmp_vector.template AsType()(i) = + type_convert(src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + } + }); + } + + template + __device__ void MoveSrcSliceWindow(const SrcDesc&, + const SrcSliceMoveStepIdx& src_slice_move_step_idx) + { + constexpr auto src_desc = SrcDesc{}; + + const auto src_slice_move_step_iter = + make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx)); + + move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); + } + __device__ void SetSrcCoord(const Index& src_ref_idx) + { + src_ref_coord_ = make_tensor_coordinate(SrcDesc{}, src_ref_idx); + } + + private: + SrcCoord src_ref_coord_; +}; + +/** + * @brief Threadwise data transfer + * + * Do NOT involve any tensor coordinates with StaticBuffer + * + */ +template ::type = false> +struct ThreadwiseTensorSliceTransfer_StaticToStatic +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + static constexpr index_t PackedSize = []() { + if constexpr(is_same_v, 
pk_i4_t>) + return 2; + else + return 1; + }(); + + __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic( + const ElementwiseOperation& element_op) + : element_op_{element_op} + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), + "wrong! Buffer need to be StaticBuffer"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); + + // scalar per access on each dim + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + if constexpr(is_same, pk_i4_t>::value) + { + static_for<0, num_access, 1>{}([&](auto idx_1d) { + typename vector_type_maker::type + src_tmp_vector; + + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + + // copy data from src_buf into dst_vector + static_for<0, DstScalarPerVector / PackedSize, 1>{}([&](auto i) { + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + src_tmp_vector.template AsType()(i) = src_buf[Number{}]; + }); + + // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to + // DstData) + vector_type_maker_t dst_tmp_vector; + + constexpr index_t pack_size = 8; + + static_assert(DstScalarPerVector % pack_size == 0, ""); + + using src_v_t = typename vector_type_maker_t::type; + using dst_v_t = typename vector_type_maker_t::type; + + static_for<0, DstScalarPerVector / pack_size, 1>{}([&](auto i) { + ck::tensor_operation::element_wise::PassThroughPack8{}( + dst_tmp_vector.template AsType()(i), + src_tmp_vector.template AsType()[i]); + }); + + // copy data from dst_tmp_vector into dst_buf + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + dst_buf(Number{}) = dst_tmp_vector.template AsType()[i]; + }); + }); + } + else + { + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + + // copy data from src_buf into dst_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + 
idx_md + i * dst_scalar_step_in_vector); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + DstData v; + + // apply element-wise operation + element_op_(v, src_buf[Number{}]); + + // apply type convert + dst_buf(Number{}) = v; + }); + }); + } + } + + ElementwiseOperation element_op_; +}; + +// Specialized for gfx11 +// A single Wave32 is composed by double row +// Data exchange allowed between these two rows +// This RowLane Dst buf will be filled from two Src buf +// SrcA: From specific thread buffer hold by This RowLane on This Row +// SrcB: From specific thread buffer hold by This RowLane on The other Row +template ::type = false> +struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(const Index& src_idx) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); + ignore = src_idx; + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), + "wrong! Buffer need to be StaticBuffer"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); + + // scalar per access on each dim + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + + // copy data from src_buf into dst_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + // src_desc error, non constexpr, caused by merge transform + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + SrcData v_this_row, v_theother_row; + // int type temp value due to intrinsic requirement + int temp = 0; + + // apply element-wise operation + element_op_(v_this_row, src_buf[Number{}]); + + // apply intra-row permute. 
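                // [editor's note] Hedged reading of the selectors below, based on the
                // v_permlane16/v_permlanex16 encoding (nibble N of the {high, low} selector pair
                // chooses the source lane for destination lane N within a 16-lane row):
                // 0xf7e6d5c4 / 0xb3a29180 interleave the lower and upper eight lanes of the row
                // (0, 8, 1, 9, ..., 7, 15). permlanex16 applies the same per-lane selects but
                // sources from the other 16-lane row of the wave, which is what produces
                // v_theother_row; `temp` only serves as the required "old value" operand.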
+ if constexpr(IntraRowSwizzlePerm) + { + temp = __builtin_amdgcn_permlane16( + temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); + v_this_row = type_convert_sp(temp); + } + + // apply inter-row permute. + temp = __builtin_amdgcn_permlanex16(temp, + type_convert_sp(v_this_row), + LowEightRowlaneIdx, + HighEightRowLaneIdx, + 1, + 0); + v_theother_row = type_convert_sp(temp); + + if(get_thread_local_1d_id() % 32 < 16) + { + // apply type convert + dst_buf(Number{}) = type_convert_sp(v_this_row); + dst_buf(Number{}) = + type_convert_sp(v_theother_row); + } + else + { + // apply type convert + dst_buf(Number{}) = + type_convert_sp(v_this_row); + dst_buf(Number{}) = type_convert_sp(v_theother_row); + } + }); + }); + } + ElementwiseOperation element_op_{}; +}; + +// Specialized for gfx12 +template ::type = false> +struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index& src_idx) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! Not divisible"); + ignore = src_idx; + } + + template + __device__ void Run(const SrcDesc&, + const SrcSliceOriginIdx&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstSliceOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! Desc need to known at compile-time"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SliceOrigin need to known at compile-time"); + + static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(), + "wrong! Buffer need to be StaticBuffer"); + + // SrcDesc and src_slice_origin_idx are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); + constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{}); + + // scalar per access on each dim + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = + generate_sequence(detail::lambda_scalar_step_in_vector{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector, + "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector"); + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d); + + // copy data from src_buf into dst_vector + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + // src_desc error, non constexpr, caused by merge transform + constexpr index_t src_offset = src_desc.CalculateOffset( + src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector); + + SrcData v_this_row; + // int type temp value due to intrinsic requirement + int temp = 0; + + // apply element-wise operation + element_op_(v_this_row, src_buf[Number{}]); + + // apply intra-row permute. 
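                // [editor's note] Unlike the gfx11 InterRow variant above, this path only applies
                // the optional intra-row swizzle and keeps the result in the same lane; there is
                // no permlanex16 exchange with the other 16-lane row.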
+ if constexpr(IntraRowSwizzlePerm) + { + temp = __builtin_amdgcn_permlane16( + temp, type_convert_sp(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0); + v_this_row = type_convert_sp(temp); + } + + // apply type convert + dst_buf(Number{}) = type_convert_sp(v_this_row); + }); + }); + } + ElementwiseOperation element_op_{}; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp new file mode 100755 index 000000000000..168f028e2a8a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +namespace detail { +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + return (i == VectorDim) ? ScalarPerVector : 1; + } +}; + +template +struct lambda_scalar_step_in_vector +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + return (i == VectorDim) ? 1 : 0; + } +}; + +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access_for_src_and_dst +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + if(i == SrcVectorDim && i == DstVectorDim) + { + return math::lcm(SrcScalarPerVector, DstScalarPerVector); + } + else if(i == SrcVectorDim) + { + return SrcScalarPerVector; + } + else if(i == DstVectorDim) + { + return DstScalarPerVector; + } + else + { + return 1; + } + } +}; + +template +struct lambda_wave_cluster_dimension +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + if((nDim - i) == 3) + return WaveNum; + else + return 1; + } +}; + +} // namespace detail + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp new file mode 100755 index 000000000000..79e22018a621 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp @@ -0,0 +1,975 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor/static_tensor.hpp" +#include "ck/utility/is_detected.hpp" + +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + +namespace ck { + +// Assume: +// 1. 
src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +template +struct ThreadwiseTensorSliceTransfer_v3r1 +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + static constexpr auto I8 = Number<8>{}; + static constexpr auto I10 = Number<10>{}; + static constexpr auto I12 = Number<12>{}; + static constexpr auto I13 = Number<13>{}; + static constexpr auto I14 = Number<14>{}; + static constexpr auto I16 = Number<16>{}; + + static constexpr index_t PackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr auto SrcScalarPerVector = Number{}; + static constexpr auto DstScalarPerVector = Number{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1( + const SrcDesc& src_desc, + const Index& src_slice_origin, + const SrcElementwiseOperation& src_element_op, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const DstElementwiseOperation& dst_element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + src_element_op_(src_element_op), + dst_element_op_(dst_element_op) + { + if constexpr((packed_size_v) > 1) + { + static_assert(is_same_v, remove_cvref_t>, + "SrcData != DstData"); + + static_assert( + SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0, + "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type"); + + static_assert(SrcVectorDim == DstVectorDim, + "Packed data type does not support transpose"); + } + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! 
SrcBuffer and SrcData data type are inconsistent"); + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0, + "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"); + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + // maintain a container record is_src_valid, waiting for RunWrite use. 
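            // [editor's note] The per-access validity flag stored below is consumed later by
            // TransferDataFromSrcThreadScratchToDstThreadScratch(): on the in-register transpose
            // path it replaces out-of-bounds vectors with zero (op_r_v = is_src_valid ? op_r :
            // vector_t(0)) before the optional sub-dword transpose and the RunWrite() store.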
+ const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + src_oob_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType(src_data_idx_seq, is_src_valid); + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + dst_vector_type op_r_v; + + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + else if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + else if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + else + { + return 1; + } + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + using src_elem_op_vec_t = typename vector_type::type; + using dst_elem_op_vec_t = typename vector_type::type; + + using VectorSizeLookupTable = Tuple, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence>; + using VectorOffsetsLookupTable = Tuple, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence, + Sequence>; + + static_for<0, tuple_element_t::Size(), 1>{}( + [&](auto v_idx) { + constexpr auto VectorLoadSize = + tuple_element_t::At(v_idx); + constexpr auto LoadOffset = + tuple_element_t::At(v_idx); + + using src_vector_container = vector_type_maker_t; + using src_vector_container_t = typename src_vector_container::type; + + src_vector_container src_vector = + src_vector_container{src_buf.template Get( + src_coord_.GetOffset() / PackedSize + LoadOffset, true)}; + + static_for<0, VectorLoadSize / elem_op_vec_len, 1>{}([&](auto idx) { + // apply the src elementwise op and convert to DstData under the hood if + // needed + src_element_op_( + op_r_v.template AsType()(idx + LoadOffset), + src_vector.template AsType()[idx]); + }); + }); + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType(src_data_idx_seq, + op_r_v.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + template + __device__ constexpr auto + GetSrcThreadScratchIdx(Number 
thread_scratch_id = Number{}) + { + using vector_t = typename vector_type_maker::type::type; + return src_thread_scratch_tuple_(thread_scratch_id).template GetAsType(SeqIdx{}); + } + + template + __device__ void + TransferDataFromSrcThreadScratchToDstThreadScratch(Number thread_scratch_id) + { +#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE + static_ford{}([&](auto idx) { + dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; + }); +#else + // OOB Check + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using vector_t = typename vector_type_maker::type::type; + + auto op_r = src_thread_scratch_tuple_(thread_scratch_id) + .template GetAsType(src_data_idx_seq); + + const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id) + .template GetAsType(src_data_idx_seq); + + auto op_r_v = is_src_valid ? 
op_r : vector_t(0); + + src_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType(src_data_idx_seq, op_r_v); + }); + + // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ + // TODO make this logic more generic for more sub-dword datatype + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + static_assert(!is_same_v, pk_i4_t>, + "in-register transpose is not supported for pk_i4_t"); + static_assert(!is_same_v, f4x2_pk_t>, + "in-register transpose is not supported for f4x2_pk_t"); + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + static_assert(SrcVectorDim != DstVectorDim, "wrong"); + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in DstVectorDim + return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + constexpr auto packed_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access; + + static_ford{}([&](auto idx) { + dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; + }); + } +#endif + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) + { + // if there is transpose, it's done here + // if there is oob check, it's done here + // TODO move this elsewhere + TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); + + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); 
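        // [editor's note] From this point on RunWrite() only reads dst_thread_scratch_: the
        // OOB zero-fill and (optional) sub-dword transpose were already applied by the call
        // above, and each DstScalarPerVector-wide vector is written back with a single Set()
        // guarded by is_dst_valid.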
+ + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + constexpr auto dst_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + // copy data from dst_thread_scratch_ into dst_vector_container + auto dst_vector_container = dst_vector_type{ + dst_thread_scratch_.template GetAsType(dst_data_idx_seq)}; + + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + DstData dst_v; + + // apply DstElementwiseOperation + dst_element_op_(dst_v, dst_vector_container.template AsType()[i]); + }); + + // copy data from dst_vector_container to dst_buf + dst_buf.template Set( + dst_coord_.GetOffset() / PackedSize, + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move dst coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? 
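        // [editor's note] When DstResetCoordinateAfterRun is false, dst_coord_ is still parked at
        // the last position visited by RunWrite(), so GetDstCoordinateResetStep() is folded into
        // the requested window step and a single combined (compile-time-known) step is built here;
        // MoveSrcSliceWindow() above does the same for the source coordinate.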
+ const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + return make_naive_tensor_descriptor_packed(src_access_lengths); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + private: + static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; + static constexpr auto src_oob_thread_scratch_desc_ = + decltype(GetSrcThreadScratchDescriptor()){}; + static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; + + using SrcThreadScratch = + StaticTensorTupleOfVectorBuffer; + + using SrcOOBThreadScratch = + StaticTensorTupleOfVectorBuffer; + + using 
DstThreadScratch = StaticTensorTupleOfVectorBuffer; + + StaticallyIndexedArray src_thread_scratch_tuple_; + StaticallyIndexedArray src_oob_thread_scratch_tuple_; + + DstThreadScratch dst_thread_scratch_; + + SrcCoord src_coord_; + DstCoord dst_coord_; + const SrcElementwiseOperation src_element_op_; + const DstElementwiseOperation dst_element_op_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp new file mode 100755 index 000000000000..174b82f8700f --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp @@ -0,0 +1,1066 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor/static_tensor.hpp" + +namespace ck { + +namespace detail { +// TODO: How to fix this? It uses an struct instead of lambda because lambda +// doesn't have constructor +template +struct lambda_scalar_per_access_for_src_and_dst_idle +{ + __host__ __device__ constexpr auto operator()(index_t i) const + { + if(i == SrcVectorDim && i == DstVectorDim) + { + return math::lcm(SrcScalarPerVector, DstScalarPerVector); + } + else if(i == SrcVectorDim) + { + return SrcScalarPerVector; + } + else if(i == DstVectorDim) + { + return DstScalarPerVector; + } + else + { + return 1; + } + } +}; + +} // namespace detail + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +// 5. Dequantization happened between read and write. 
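+//
+// Expected per-slice call sequence (a sketch of intended usage, inferred from the
+// member functions defined below rather than prescribed by this header):
+//   RunRead(src_desc, src_buf);          // load quantized source data into registers
+//   RunScaleRead(scale_desc, scale_buf); // load the matching scale values
+//   RunWrite(dst_desc, dst_buf);         // dequantize in registers, then store to dst
+//   MoveSrcSliceWindow(...) / MoveDstSliceWindow(...) to advance to the next slice.
+// The dequantization itself happens inside RunWrite(), via
+// TransferDataFromSrcThreadScratchToDstThreadScratch().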
+template +struct ThreadwiseTensorSliceTransfer_v3r1_dequant +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using ScaleCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_dequant( + const SrcDesc& src_desc, + const Index& src_slice_origin, + const SrcElementwiseOperation& src_element_op, + const ScaleDesc& scale_desc, + const Index& scale_slice_origin, + const ScaleElementwiseOperation& scale_element_op, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const DstElementwiseOperation& dst_element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + scale_coord_(make_tensor_coordinate(scale_desc, scale_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + src_element_op_(src_element_op), + scale_element_op_(scale_element_op), + dst_element_op_(dst_element_op) + { + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetScaleSliceOrigin(const ScaleDesc& scale_desc, + const Index& scale_slice_origin_idx) + { + scale_coord_ = make_tensor_coordinate(scale_desc, scale_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? 
-src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType( + src_data_idx_seq, src_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + template + __device__ void RunScaleRead(const ScaleDesc& scale_desc, const ScaleBuffer& scale_buf) + { + static_assert(ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + ScaleBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! 
ScaleBuffer and ScaleData data type are inconsistent"); + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scale_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access; + + constexpr auto scale_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_scale_access_lengths = + container_reorder_given_new2old(scale_access_lengths, scale_dim_access_order); + + // make forward steps + const auto scale_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? scale_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(scale_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto scale_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -scale_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(scale_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_scale_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_scale_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_scale_access_lengths[j] + ordered_scale_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate scale data index + constexpr auto scale_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_scale_access_idx[i] + : ordered_scale_access_lengths[i] - 1 - + ordered_scale_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, scale_dim_access_order) * + scale_scalar_per_access; + }(); + + constexpr auto scale_data_idx_seq = + generate_sequence_v2([&](auto i) { return Number{}; }, + Number{}); + + const bool is_scale_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + scale_desc, scale_coord_); + + using scale_vector_type = vector_type_maker_t; + using scale_vector_t = typename scale_vector_type::type; + + // copy data from scale_buf into scale_vector_container + auto scale_vector_container = scale_vector_type{ + scale_buf.template Get(scale_coord_.GetOffset(), is_scale_valid)}; + + // copy data from scale_vector_container into scale_thread_scratch_ + scale_thread_scratch_.template SetAsType( + scale_data_idx_seq, scale_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = + ordered_scale_access_idx[i] < ordered_scale_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_scale_access_idx[j] == ordered_scale_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move scale coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate(scale_desc, + scale_coord_, + scale_forward_steps[scale_dim_access_order[i]]); + } + else + { + move_tensor_coordinate(scale_desc, + scale_coord_, + scale_backward_steps[scale_dim_access_order[i]]); + } + } + }); + }); + + // don't need to move scale coordinate back to slice origin + /* + if constexpr(SrcResetCoordinateAfterRun) + { + const auto scale_reset_step = + make_tensor_coordinate_step(scale_desc, GetScaleCoordinateResetStep()); + + move_tensor_coordinate(scale_desc, scale_coord_, scale_reset_step); + } + */ + } + + template + __device__ void + TransferDataFromSrcThreadScratchToDstThreadScratch(Number thread_scratch_id) + { +#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE + static_ford{}([&](auto idx) { + // convert from SrcData to DstData here + dst_thread_scratch_(idx) = + type_convert(src_thread_scratch_tuple_[thread_scratch_id][idx]); + }); +#else + // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ + // TODO make this logic more generic for more sub-dword datatype + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + static_assert(SrcVectorDim != DstVectorDim, "wrong"); + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + 
detail::lambda_scalar_per_access_for_src_and_dst_idle{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in DstVectorDim + return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + + // Do fast numeric convert + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst_idle{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + using src_converted_vector_type = vector_type_maker_t; + using src_converted_vector_t = typename src_converted_vector_type::type; + // Vector-wise type convert + static_ford{}([&](auto access_idx) { + auto src_vector_container = src_vector_type{ + src_thread_scratch_tuple_[thread_scratch_id].template GetAsType( + access_idx)}; + + auto src_converted_vector_container = + src_converted_vector_type{fast_numeric_converter(src_vector_container)}; + + src_converted_thread_scratch_.template SetAsType( + access_idx, + src_converted_vector_container.template AsType()[I0]); + }); + + // Element-scale operation, expect packed multiplication + static_ford{}([&](auto idx) { + DstData dst_v; + constexpr auto scale_idx = Sequence{}; + // printf("Tid: %03d, scale: %04x\n", get_thread_local_1d_id(), + // *(reinterpret_cast(&scale_thread_scratch_[scale_idx]))); + src_element_op_(dst_v, + src_converted_thread_scratch_[idx] * scale_thread_scratch_[scale_idx]); + dst_thread_scratch_(idx) = dst_v; + }); +#endif + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) + { + // if there is transpose, it's done here + // TODO move this elsewhere + TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); + + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! 
SrcBuffer or DstBuffer data type is wrong"); + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + constexpr auto dst_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + // copy data from dst_thread_scratch_ into dst_vector_container + auto dst_vector_container = dst_vector_type{ + dst_thread_scratch_.template GetAsType(dst_data_idx_seq)}; + + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + DstData dst_v; + + // apply DstElementwiseOperation + dst_element_op_(dst_v, dst_vector_container.template AsType()[i]); + + dst_vector_container.template AsType()(i) = dst_v; + }); + + // copy data from dst_vector_container to dst_buf + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move dst coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? 
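+        // If RunWrite() already moved dst_coord_ back to the slice origin
+        // (DstResetCoordinateAfterRun == true), the raw step is used directly;
+        // otherwise the reset step from GetDstCoordinateResetStep() is folded into
+        // this single coordinate move.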
+ const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetScaleThreadScratchDescriptor() + { + + constexpr auto scale_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto scale_access_lengths = SliceLengths{} / scale_scalar_per_access; + + constexpr auto scale_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(scale_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(scale_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(scale_access_lengths_and_vector_length[i], + scale_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(scale_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + 
make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + private: + static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; + static constexpr auto scale_thread_scratch_desc_ = + decltype(GetScaleThreadScratchDescriptor()){}; + static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; + + /* + template + struct ScaleThreadScratchDesc{}; + */ + + // Registers, contain raw data loaded from global buffer + using SrcThreadScratch = StaticTensorTupleOfVectorBuffer; + + // Registers, contain fast converted data + using SrcThreadConvertedScratch = + StaticTensorTupleOfVectorBuffer; + + // Registers, contain scale data + using ScaleThreadScratch = StaticTensorTupleOfVectorBuffer; + + // Registers, contain dequantized data + using DstThreadScratch = StaticTensorTupleOfVectorBuffer; + + using FastTypeConverter = tensor_operation::element_wise:: + FastNumericArrayConverter; + + StaticallyIndexedArray src_thread_scratch_tuple_; + SrcThreadConvertedScratch src_converted_thread_scratch_; + ScaleThreadScratch scale_thread_scratch_; + + DstThreadScratch dst_thread_scratch_; + FastTypeConverter fast_numeric_converter; + + SrcCoord src_coord_; + ScaleCoord scale_coord_; + DstCoord dst_coord_; + const SrcElementwiseOperation src_element_op_; + const ScaleElementwiseOperation scale_element_op_; + const DstElementwiseOperation dst_element_op_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp new file mode 100755 index 000000000000..50f1e21beb35 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp @@ -0,0 +1,943 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor/static_tensor.hpp" +#include "ck/utility/is_detected.hpp" + +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + +namespace ck { + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. 
Use thread buffer +template +struct ThreadwiseTensorSliceTransfer_v3r1_gather +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + static constexpr auto I8 = Number<8>{}; + static constexpr auto I10 = Number<10>{}; + static constexpr auto I12 = Number<12>{}; + static constexpr auto I13 = Number<13>{}; + static constexpr auto I14 = Number<14>{}; + static constexpr auto I16 = Number<16>{}; + + static constexpr index_t PackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr auto SrcScalarPerVector = Number{}; + static constexpr auto DstScalarPerVector = Number{}; + + static constexpr index_t gather_num = SliceLengths{}.At(Number{}); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1_gather( + const SrcDesc& src_desc, + const Index& src_slice_origin, + const SrcElementwiseOperation& src_element_op, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const DstElementwiseOperation& dst_element_op, + const StaticallyIndexedArray& gather_offsets) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + src_element_op_(src_element_op), + dst_element_op_(dst_element_op), + gather_offsets_(gather_offsets) + { + if constexpr((packed_size_v) > 1) + { + static_assert(is_same_v, remove_cvref_t>, + "SrcData != DstData"); + + static_assert( + SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0, + "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type"); + + static_assert(SrcVectorDim == DstVectorDim, + "Packed data type does not support transpose"); + } + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + + auto adjusted_origin_idx = [&]() { + Index idx; + static_for<0, nDim, 1>{}([&](auto i) { + idx(i) = i.value == GatherDim ? 0 : src_slice_origin_idx[Number{}]; + }); + return idx; + }(); + src_coord_ = make_tensor_coordinate(src_desc, adjusted_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void RunRead(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + Number thread_scratch_id = Number{}) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! 
SrcBuffer and SrcData data type are inconsistent"); + + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector_) == 0, + "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector"); + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + constexpr auto ordered_gather_dim = src_dim_access_order[GatherDim]; + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(src_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + auto gather_offset = + gather_offsets_(ordered_src_access_idx[Number{}]); + + const IndexType ld_offset = src_coord_.GetOffset() / PackedSize + gather_offset; + src_oob_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType(src_data_idx_seq, true); + + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + auto src_vector_container = + src_vector_type{src_buf.template Get(ld_offset, true)}; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + dst_vector_type op_r_v; + + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + else if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + else if constexpr(is_detected::value) + { + if constexpr(decltype(src_element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + else + { + return 1; + } + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + using src_elem_op_vec_t = typename vector_type::type; + using dst_elem_op_vec_t = typename vector_type::type; + + static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto idx) { + // apply the src elementwise op and convert to DstData under the hood if needed + src_element_op_(op_r_v.template AsType()(idx), + src_vector_container.template AsType()[idx]); + }); + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType(src_data_idx_seq, + op_r_v.template AsType()[I0]); + + auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + move_on_dim_(i) &= i.value != ordered_gather_dim; + }); + + return move_on_dim_; + } + (); + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + template + __device__ constexpr auto + GetSrcThreadScratchIdx(Number thread_scratch_id = Number{}) + { + using vector_t = typename vector_type_maker::type::type; + return src_thread_scratch_tuple_(thread_scratch_id).template GetAsType(SeqIdx{}); + } + + template + __device__ void + TransferDataFromSrcThreadScratchToDstThreadScratch(Number thread_scratch_id) + { +#if !CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE + 
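+        // Simple path: copy every element from the selected per-k scratch buffer
+        // into dst_thread_scratch_ as-is; no in-register sub-dword transpose is
+        // attempted in this configuration.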
static_ford{}([&](auto idx) { + dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; + }); +#else + + // OOB Check + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + constexpr auto src_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using vector_t = typename vector_type_maker::type::type; + + auto op_r = src_thread_scratch_tuple_(thread_scratch_id) + .template GetAsType(src_data_idx_seq); + + auto op_r_v = op_r; + + src_thread_scratch_tuple_(thread_scratch_id) + .template SetAsType(src_data_idx_seq, op_r_v); + }); + + // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_ + // TODO make this logic more generic for more sub-dword datatype + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + static_assert(!is_same_v, pk_i4_t>, + "in-register transpose is not supported for pk_i4_t"); + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + static_assert(SrcVectorDim != DstVectorDim, "wrong"); + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get 
DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in DstVectorDim + return src_thread_scratch_tuple_[thread_scratch_id].GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + constexpr auto packed_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access; + + static_ford{}([&](auto idx) { + dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx]; + }); + } +#endif + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, + DstBuffer& dst_buf, + Number thread_scratch_id = Number{}) + { + // if there is transpose, it's done here + // if there is oob check, it's done here + // TODO move this elsewhere + TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); + + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, forward_step_idx); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0; + }); + + return make_tensor_coordinate_step(dst_desc, backward_step_idx); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + constexpr auto dst_data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + // copy data from dst_thread_scratch_ into dst_vector_container + auto dst_vector_container = dst_vector_type{ + dst_thread_scratch_.template GetAsType(dst_data_idx_seq)}; + + static_for<0, DstScalarPerVector, 1>{}([&](auto i) { + DstData dst_v; + + // apply DstElementwiseOperation + dst_element_op_(dst_v, dst_vector_container.template AsType()[i]); + + dst_vector_container.template AsType()(i) = dst_v; + }); + + // copy data from dst_vector_container to dst_buf + dst_buf.template Set( + dst_coord_.GetOffset() / PackedSize, + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move dst coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { + reset_src_data_step_(i) = i.value == GatherDim ? 0 : -src_data_idx[i]; + }); + + return reset_src_data_step_; + }(); + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? 
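+        // Note: the gather offsets only influence the source side (they are applied
+        // in RunRead()); the destination window move below does not involve them.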
+ const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + return make_naive_tensor_descriptor_packed(src_access_lengths); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + private: + static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){}; + static constexpr auto src_oob_thread_scratch_desc_ = + decltype(GetSrcThreadScratchDescriptor()){}; + static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){}; + + using SrcThreadScratch = + StaticTensorTupleOfVectorBuffer; + + using SrcOOBThreadScratch = + StaticTensorTupleOfVectorBuffer; + + using 
DstThreadScratch = StaticTensorTupleOfVectorBuffer; + + StaticallyIndexedArray src_thread_scratch_tuple_; + StaticallyIndexedArray src_oob_thread_scratch_tuple_; + + DstThreadScratch dst_thread_scratch_; + + SrcCoord src_coord_; + DstCoord dst_coord_; + const SrcElementwiseOperation src_element_op_; + const DstElementwiseOperation dst_element_op_; + StaticallyIndexedArray gather_offsets_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp new file mode 100755 index 000000000000..f0d793456d33 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp @@ -0,0 +1,804 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor/static_tensor.hpp" +#include "ck/utility/is_detected.hpp" + +namespace ck { + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +template +struct ThreadwiseTensorSliceTransfer_v3r2 +{ + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v3r2( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op) + { + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto src_i) { + src_coords_(src_i) = + make_tensor_coordinate(src_descs.At(src_i), src_slice_origin_idxs[src_i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto dst_i) { + dst_coords_(dst_i) = + make_tensor_coordinate(dst_descs.At(dst_i), dst_slice_origin_idxs[dst_i]); + }); + } + + template + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access 
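+        // Illustrative example (values assumed, not prescribed by this header): with
+        // nSrc = 2, nDim = 2, SliceLengths = <2, 8>, SrcVectorDim = 1 and
+        // SrcsScalarPerVector = <8, 2>, the per-source scalar-per-access sequences are
+        // <1, 8> and <1, 2>, giving src_access_lengths_tuple = { <2, 1>, <2, 4> }:
+        // source 0 is covered by two 8-wide vector loads and source 1 by eight 2-wide
+        // loads, while both traverse the same <2, 8> slice window.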
+ constexpr auto src_scalar_per_access_tuple = generate_tuple( + [&](auto src_i) { + return generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + }, + Number{}); + + constexpr auto src_access_lengths_tuple = generate_tuple( + [&](auto src_i) { + return SliceLengths{} / src_scalar_per_access_tuple.At(src_i); + static_assert( + SliceLengths::At(SrcVectorDim) % SrcsScalarPerVector::At(src_i) == 0, + "SliceLengths[SrcVectorDim] must be divisible by SrcsScalarPerVector"); + }, + Number{}); + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths_tuple = generate_tuple( + [&](auto src_i) { + return container_reorder_given_new2old(src_access_lengths_tuple.At(src_i), + src_dim_access_order); + }, + Number{}); + + // make forward steps + const auto src_forward_steps_tuple = generate_tuple( + [&](auto src_i) { + return generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = + (i.value == j.value) ? src_scalar_per_access_tuple.At(src_i)[i] : 0; + }); + + return make_tensor_coordinate_step(src_descs.At(src_i), forward_step_idx); + }, + Number{}); + }, + Number{}); + + // make backward steps + const auto src_backward_steps_tuple = generate_tuple( + [&](auto src_i) { + return generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) + ? -src_scalar_per_access_tuple.At(src_i)[i] + : 0; + }); + + return make_tensor_coordinate_step(src_descs.At(src_i), backward_step_idx); + }, + Number{}); + }, + Number{}); + + // loop over tensor and copy + static_for<0, nSrc, 1>{}([&](auto src_i) { + static_ford>{}( + [&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths_tuple[j] + + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] + ? 
ordered_src_access_idx[i] + : ordered_src_access_lengths_tuple.At(src_i)[i] - + 1 - ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access_tuple.At(src_i); + }(); + + constexpr auto src_data_idx_seq = + generate_sequence_v2([&](auto i) { return Number{}; }, + Number{}); + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_descs.At(src_i), src_coords_.At(src_i)); + + using src_vector_type = vector_type_maker_t, + SrcsScalarPerVector::At(src_i)>; + using src_vector_t = typename src_vector_type::type; + + // copy data from src_buf into src_vector_container + auto src_vector_container = + src_vector_type{src_bufs.At(src_i).template Get( + src_coords_.At(src_i).GetOffset(), is_src_valid)}; + + // copy data from src_vector_container into src_thread_scratch_ + src_thread_scratch_tuple_(thread_scratch_id) + .At(src_i) + .template SetAsType( + src_data_idx_seq, + src_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < + ordered_src_access_lengths_tuple.At(src_i)[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == + ordered_src_access_lengths_tuple.At(src_i)[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move src coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_descs.At(src_i), + src_coords_.At(src_i), + src_forward_steps_tuple.At(src_i)[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_descs.At(src_i), + src_coords_.At(src_i), + src_backward_steps_tuple.At(src_i)[src_dim_access_order[i]]); + } + } + }); + }); + }); + + static_for<0, nSrc, 1>{}([&](auto src_i) { + // move src coordinate back to slice origin (or not) + if constexpr(SrcsResetCoordinateAfterRun::At(src_i)) + { + const auto src_reset_step = make_tensor_coordinate_step( + src_descs.At(src_i), GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_descs.At(src_i), src_coords_.At(src_i), src_reset_step); + } + }); + } + + template + __device__ void + TransferDataFromSrcThreadScratchToDstThreadScratch(Number thread_scratch_id) + { + // TODO: Add support for CK_EXPERIMENTAL_USE_IN_REGISTER_SUB_DWORD_TRANSPOSE + // (it requires to add Elementwise support in transpose_vectors) + static_ford{}([&](auto idx) { + const auto src_data_refs = generate_tie( + [&](auto src_i) -> const auto& { + return src_thread_scratch_tuple_[thread_scratch_id].At(src_i)[idx]; + }, + Number{}); + + auto dst_data_refs = generate_tie( + [&](auto dst_i) -> auto& { return dst_thread_scratch_tuple_.At(dst_i)(idx); }, + Number{}); + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + } + + template + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers& dst_bufs, + Number thread_scratch_id = Number{}) + { + // if there is transpose, it's done here + // TODO move this elsewhere + TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id); + + // src scalar per access on each dim + // TODO: don't use this + constexpr auto dst_scalar_per_access_tuple = generate_tuple( + [&](auto dst_i) { + return generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + }, + Number{}); + + constexpr auto dst_access_lengths_tuple = generate_tuple( + [&](auto dst_i) { 
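+                // e.g. (illustrative values only): SliceLengths = <2, 8> with a
+                // per-destination scalar-per-access of <1, 4> yields access lengths
+                // <2, 2>, i.e. four 4-wide vector stores for that destination.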
return SliceLengths{} / dst_scalar_per_access_tuple.At(dst_i); }, + Number{}); + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths_tuple = generate_tuple( + [&](auto dst_i) { + return container_reorder_given_new2old(dst_access_lengths_tuple.At(dst_i), + dst_dim_access_order); + }, + Number{}); + + // make forward steps + const auto dst_forward_steps_tuple = generate_tuple( + [&](auto dst_i) { + return generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = + (i.value == j.value) ? dst_scalar_per_access_tuple.At(dst_i)[i] : 0; + }); + + return make_tensor_coordinate_step(dst_descs.At(dst_i), forward_step_idx); + }, + Number{}); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps_tuple = generate_tuple( + [&](auto dst_i) { + return generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) + ? -dst_scalar_per_access_tuple.At(dst_i)[i] + : 0; + }); + + return make_tensor_coordinate_step(dst_descs.At(dst_i), backward_step_idx); + }, + Number{}); + }, + Number{}); + + // loop over tensor and copy + static_for<0, nDst, 1>{}([&](auto dst_i) { + static_ford>{}( + [&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_idx[I0]; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths_tuple.At(dst_i)[j] + + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] + ? 
ordered_dst_access_idx[i] + : ordered_dst_access_lengths_tuple.At(dst_i)[i] - + 1 - ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access_tuple.At(dst_i); + }(); + + constexpr auto dst_data_idx_seq = + generate_sequence_v2([&](auto i) { return Number{}; }, + Number{}); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid( + dst_descs.At(dst_i), dst_coords_.At(dst_i)); + + using dst_vector_type = vector_type_maker_t, + DstsScalarPerVector::At(dst_i)>; + using dst_vector_t = typename dst_vector_type::type; + + // copy data from dst_thread_scratch_ into dst_vector_container + auto dst_vector_container = dst_vector_type{ + dst_thread_scratch_tuple_.At(dst_i).template GetAsType( + dst_data_idx_seq)}; + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(dst_i.value)); + + // copy data from dst_vector_container to dst_buf + dst_bufs.At(dst_i).template Update( + dst_coords_.At(dst_i).GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < + ordered_dst_access_lengths_tuple.At(dst_i)[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == + ordered_dst_access_lengths_tuple.At(dst_i)[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move dst coord + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_descs.At(dst_i), + dst_coords_.At(dst_i), + dst_forward_steps_tuple.At(dst_i)[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_descs.At(dst_i), + dst_coords_.At(dst_i), + dst_backward_steps_tuple.At(dst_i)[dst_dim_access_order[i]]); + } + } + }); + }); + }); + + // move dst coordinate back to slice origin (or not) + static_for<0, nDst, 1>{}([&](auto dst_i) { + if constexpr(DstsResetCoordinateAfterRun::At(dst_i)) + { + const auto dst_reset_step = make_tensor_coordinate_step( + dst_descs.At(dst_i), GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_descs.At(dst_i), dst_coords_.At(dst_i), dst_reset_step); + } + }); + } + + template + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index 
ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_scalar_per_access; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + template + __device__ static constexpr auto GetDstCoordinateResetStep() + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<1, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_scalar_per_access.At(dst_i); + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + const Index& src_slice_origin_step_idx) + { + static_for<0, nSrc, 1>{}([&](auto src_i) { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcsResetCoordinateAfterRun::At(src_i) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = + make_tensor_coordinate_step(src_descs.At(src_i), adjusted_step_idx); + + move_tensor_coordinate(src_descs.At(src_i), src_coords_.At(src_i), adjusted_step); + }); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + const Index& dst_slice_origin_step_idx) + { + static_for<0, nDst, 1>{}([&](auto dst_i) { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstsResetCoordinateAfterRun::At(dst_i) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? 
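+            // For example (assumed configuration): with nSrc = 2 and
+            // SrcsResetCoordinateAfterRun = <true, false>, source 0 already sits at its
+            // slice origin after RunRead(), so its window step is used unchanged, while
+            // source 1 was left at its last access position and has its
+            // GetSrcCoordinateResetStep() contribution folded into the same step.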
+ const auto adjusted_step = + make_tensor_coordinate_step(dst_descs.At(dst_i), adjusted_step_idx); + + move_tensor_coordinate(dst_descs.At(dst_i), dst_coords_.At(dst_i), adjusted_step); + }); + } + + template + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = + container_push_back(sequence_to_tuple_of_number(src_access_lengths), + Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + template + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, + Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = + container_push_back(sequence_to_tuple_of_number(dst_access_lengths), + Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto MakeSrcThreadScratchTuple() + { + return generate_tuple( + [&](auto src_i) { + constexpr auto src_thread_scratch_desc = + decltype(GetSrcThreadScratchDescriptor()){}; + using SrcThreadScratch = + StaticTensorTupleOfVectorBuffer, + SrcsScalarPerVector::At(src_i), + decltype(src_thread_scratch_desc), + true>; + return SrcThreadScratch{}; + }, + Number{}); + } + + __device__ static constexpr auto MakeDstThreadScratchTuple() + { + return generate_tuple( + [&](auto dst_i) { + constexpr auto dst_thread_scratch_desc = + decltype(GetDstThreadScratchDescriptor()){}; + using DstThreadScratch = + StaticTensorTupleOfVectorBuffer, + DstsScalarPerVector::At(dst_i), + 
decltype(dst_thread_scratch_desc), + true>; + return DstThreadScratch{}; + }, + Number{}); + } + + private: + using SrcThreadScratchTuple = decltype(MakeSrcThreadScratchTuple()); + using DstThreadScratchTuple = decltype(MakeDstThreadScratchTuple()); + + StaticallyIndexedArray src_thread_scratch_tuple_; + + DstThreadScratchTuple dst_thread_scratch_tuple_; + + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp new file mode 100755 index 000000000000..6a6c1f2ac5bf --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v4r1.hpp @@ -0,0 +1,175 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" + +namespace ck { +// Assume: +// 1. src: +// 1. SrcDesc is known at compile-time +// 2. SrcBuffer is DynamicBuffer +// 3. src_ref_idx is known at run-time +// 4. SrcRefToOriginDisplacement is known at compile-time +// 5. use #-step +// 2. dst: +// 1. DstDesc is known at compile-time +// 2. DstBuffer is StaticBuffer +// 3. DstOriginIdx is known at compile-time +// 4. use direct address calculation +// 3. vector access on src +template ::type = false> +struct ThreadwiseTensorSliceTransfer_v4r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v4r1(const Index& src_ref_idx) + : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx)) + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_for<0, nDim, 1>{}([](auto i) { + static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0, "wrong!"); + }); + } + + template + __device__ void Run(const SrcDesc&, + const SrcRefToOriginDisplacement&, + const SrcBuffer& src_buf, + const DstDesc&, + const DstOriginIdx&, + DstBuffer& dst_buf) const + { + static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), + "wrong! SrcDesc and DstDesc need to known at compile-time"); + + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); + + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! 
SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); + + // SrcDesc and DstDesc are known at compile-time + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; + + // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time + constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); + constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{}); + + // tensor descriptor for src_vector + constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; + + constexpr auto src_vector_tensor_strides = container_reorder_given_old2new( + container_reverse_exclusive_scan( + container_reorder_given_new2old(src_vector_tensor_lengths, + SrcVectorTensorContiguousDimOrder{}), + math::multiplies{}, + I1), + SrcVectorTensorContiguousDimOrder{}); + + constexpr auto src_vector_desc = + make_naive_tensor_descriptor(sequence_to_tuple_of_number(src_vector_tensor_lengths), + sequence_to_tuple_of_number(src_vector_tensor_strides)); + + // access order and lengths + constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths; + + constexpr auto dim_access_order = DimAccessOrder{}; + + constexpr auto ordered_access_lengths = + container_reorder_given_new2old(access_lengths, dim_access_order); + + static_ford{}([&](auto ordered_access_idx) { + // position in slice window + constexpr auto data_to_origin_disp_idx = + ordered_access_idx.ReorderGivenOld2New(dim_access_order) * + src_vector_tensor_lengths; + + // src coordinate at starting point of src_vector + constexpr auto src_ref_to_data_disp_idx = + src_ref_to_origin_disp_idx + data_to_origin_disp_idx; + + constexpr auto src_ref_to_data_disp_coord_step = + make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx); + + auto src_data_coord = src_ref_coord_; + + move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step); + + vector_type_maker_t src_vector; + + using src_vector_t = typename decltype(src_vector)::type; + + const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid( + src_desc, src_data_coord); + + // copy data from src_buf into src_vector + src_vector.template AsType()(I0) = + src_buf.template Get(src_data_coord.GetOffset(), is_src_valid); + + // copy data from src_vector into dst_buf (also cast from SrcData to DstData) + static_ford{}([&](auto src_vector_idx_) { + constexpr auto src_vector_idx = to_multi_index(src_vector_idx_); + + constexpr index_t src_vector_offset = + src_vector_desc.CalculateOffset(src_vector_idx); + + constexpr index_t dst_offset = dst_desc.CalculateOffset( + dst_origin_idx + data_to_origin_disp_idx + src_vector_idx); + + dst_buf(Number{}) = type_convert( + src_vector.template AsType()[Number{}]); + }); + }); + } + + template + __device__ void MoveSrcSliceWindow(const SrcDesc&, + const SrcSliceMoveStepIdx& src_slice_move_step_idx) + { + constexpr auto src_desc = SrcDesc{}; + + const auto src_slice_move_step_iter = + make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx)); + + move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter); + } + + private: + SrcCoord src_ref_coord_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp new file mode 100755 index 000000000000..40ebdeff08e8 --- /dev/null 
+++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v5r1.hpp @@ -0,0 +1,614 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +namespace ck { + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +// 4. Use thread buffer +template // control whether to move back dst coordinate after each + // RunWrite(), will be fused with MoveDstSliceWindow to + // save addr computation +struct ThreadwiseTensorSliceTransfer_v5r1 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); + using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); + + __device__ constexpr ThreadwiseTensorSliceTransfer_v5r1(const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)) + { + // TODO: fix this + static_assert(is_same::value, + "wrong! current implementation assume SrcData and DstData are same type"); + + static_for<0, nDim, 1>{}([](auto i) { + static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0 && + SliceLengths::At(i) % DstVectorTensorLengths::At(i) == 0, + "wrong!"); + }); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void + RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks) + { + static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! 
SrcBuffer and SrcData data type are inconsistent"); + + // tensor descriptor for src_vector + constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; + + constexpr auto src_vector_tensor_strides = container_reorder_given_old2new( + container_reverse_exclusive_scan( + container_reorder_given_new2old(src_vector_tensor_lengths, + SrcVectorTensorContiguousDimOrder{}), + math::multiplies{}, + I1), + SrcVectorTensorContiguousDimOrder{}); + + constexpr auto src_vector_desc = + make_naive_tensor_descriptor(sequence_to_tuple_of_number(src_vector_tensor_lengths), + sequence_to_tuple_of_number(src_vector_tensor_strides)); + + // access order and lengths + constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // make forward steps + const auto src_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0; + }); + + return make_tensor_coordinate_step( + src_desc, forward_step_idx, src_step_hacks[I0][i]); + }, + Number{}); + + // make backward steps + const auto src_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0; + }); + + return make_tensor_coordinate_step( + src_desc, backward_step_idx, src_step_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_src_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_idx[I0]; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_src_access_idx[i] + : ordered_src_access_lengths[i] - 1 - + ordered_src_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_vector_tensor_lengths; + }(); + + vector_type_maker_t src_vector; + + using src_vector_t = typename decltype(src_vector)::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf to src_vector + src_vector.template AsType()(I0) = + src_buf.template Get(src_coord_.GetOffset(), is_src_valid); + + // copy data from src_vector to buffer_ + static_ford{}([&](auto src_vector_idx_) { + constexpr auto src_vector_idx = to_multi_index(src_vector_idx_); + + constexpr index_t src_vector_offset = + src_vector_desc.CalculateOffset(src_vector_idx); + + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(src_data_idx + src_vector_idx); + + buffer_(Number{}) = + src_vector.template AsType()[Number{}]; + }); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]); + } + } + }); + }); + + // move src coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + } + + template + __device__ void + RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks) + { + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or + DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "wrong!"); + + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); + + // tensor descriptor for dst_vector + constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{}; + + constexpr auto dst_vector_tensor_strides = container_reorder_given_old2new( + container_reverse_exclusive_scan( + container_reorder_given_new2old(dst_vector_tensor_lengths, + DstVectorTensorContiguousDimOrder{}), + math::multiplies{}, + I1), + DstVectorTensorContiguousDimOrder{}); + + constexpr auto dst_vector_desc = + make_naive_tensor_descriptor(sequence_to_tuple_of_number(dst_vector_tensor_lengths), + sequence_to_tuple_of_number(dst_vector_tensor_strides)); + + // dst access order and lengths + constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // make forward steps + const auto dst_forward_steps = generate_tuple( + [&](auto i) { + Index forward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + forward_step_idx(j) = (i.value == j.value) ? 
dst_vector_tensor_lengths[i] : 0; + }); + + return make_tensor_coordinate_step( + dst_desc, forward_step_idx, dst_step_hacks[I0][i]); + }, + Number{}); + + // make backward steps + const auto dst_backward_steps = generate_tuple( + [&](auto i) { + Index backward_step_idx; + + static_for<0, nDim, 1>{}([&](auto j) { + backward_step_idx(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0; + }); + + return make_tensor_coordinate_step( + dst_desc, backward_step_idx, dst_step_hacks[I1][i]); + }, + Number{}); + + // loop over tensor and copy + static_ford{}([&](auto ordered_dst_access_idx) { + // judge move forward or move backward + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = 0; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j]; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i] + : ordered_dst_access_lengths[i] - 1 - + ordered_dst_access_idx[i]; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_vector_tensor_lengths; + }(); + + vector_type_maker_t dst_vector; + + // copy data from buffer_ to dst_vector (also cast from SrcData to DstData) + static_ford{}([&](auto dst_vector_idx_) { + constexpr auto dst_vector_idx = to_multi_index(dst_vector_idx_); + + constexpr index_t buffer_offset = + buffer_desc_.CalculateOffset(dst_data_idx + dst_vector_idx); + + constexpr index_t dst_vector_offset = + dst_vector_desc.CalculateOffset(dst_vector_idx); + + dst_vector.template AsType()(Number{}) = + type_convert(buffer_[Number{}]); + }); + + using dst_vector_t = typename decltype(dst_vector)::type; + + // copy data from dst_vector to dst_buf + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + + constexpr auto move_on_dim = [&]() constexpr + { + StaticallyIndexedArray move_on_dim_; + + static_for<0, nDim, 1>{}([&](auto i) { + move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1; + + static_for{}([&](auto j) { + move_on_dim_(i) &= + ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1; + }); + }); + + return move_on_dim_; + } + (); + + // move + static_for<0, nDim, 1>{}([&](auto i) { + if constexpr(move_on_dim[i]) + { + if constexpr(forward_sweep[i]) + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]); + } + else + { + move_tensor_coordinate( + dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]); + } + } + }); + }); + + // move dst coordinate back to slice origin (or not) + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + template + __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf) + { + constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto src_step_hacks = + 
make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunRead(src_desc, src_buf, src_step_hacks); + } + + template + __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf) + { + constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform(); + + constexpr auto zeros = typename uniform_sequence_gen::type{}; + + constexpr auto dst_step_hacks = + make_tuple(generate_tuple([&](auto) { return zeros; }, Number{}), + generate_tuple([&](auto) { return zeros; }, Number{})); + + RunWrite(dst_desc, dst_buf, dst_step_hacks); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; + + constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths; + + constexpr auto src_dim_access_order = SrcDimAccessOrder{}; + + constexpr auto ordered_src_access_lengths = + container_reorder_given_new2old(src_access_lengths, src_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_src_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate src data index after last iteration in RunRead(), if it has not being reset by + // RunRead() + constexpr auto src_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, src_dim_access_order) * + src_vector_tensor_lengths; + }(); + + // + constexpr auto reset_src_data_step = [&]() { + Index reset_src_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; }); + + return reset_src_data_step_; + }(); + + return reset_src_data_step; + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{}; + + constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths; + + constexpr auto dst_dim_access_order = DstDimAccessOrder{}; + + constexpr auto ordered_dst_access_lengths = + container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order); + + // judge move forward or move backward during the last iteration + constexpr auto forward_sweep = [&]() { + StaticallyIndexedArray forward_sweep_; + + forward_sweep_(I0) = true; + + static_for<1, nDim, 1>{}([&](auto i) { + index_t tmp = ordered_dst_access_lengths[I0] - 1; + + static_for<0, i, 1>{}([&](auto j) { + tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1; + }); + + forward_sweep_(i) = tmp % 2 == 0; + }); + + return forward_sweep_; + }(); + + // calculate dst data index after last iteration in RunWrite(), if it has not being reset by + // RunWrite() + constexpr auto dst_data_idx = [&]() { + Index ordered_idx; + + static_for<0, nDim, 1>{}([&](auto i) { + ordered_idx(i) = forward_sweep[i] ? 
ordered_dst_access_lengths[i] - 1 : 0; + }); + + return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) * + dst_vector_tensor_lengths; + }(); + + // + constexpr auto reset_dst_data_step = [&]() { + Index reset_dst_data_step_; + + static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; }); + + return reset_dst_data_step_; + }(); + + return reset_dst_data_step; + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void + MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx, + const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRun ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step( + src_desc, adjusted_step_idx, src_move_slice_window_step_hack); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by RunWrite(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRun ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + static constexpr auto buffer_desc_ = + make_naive_tensor_descriptor_packed(sequence_to_tuple_of_number(SliceLengths{})); + + static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize(); + + StaticBuffer buffer_; + + SrcCoord src_coord_; + DstCoord dst_coord_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp new file mode 100755 index 000000000000..644877d3931f --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1.hpp @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
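+// Typical usage (illustrative sketch only; buffer, descriptor and step names below are
+// placeholders, not definitions from this file):
+//
+//   auto transfer = ThreadwiseTensorSliceTransfer_v6r1<...>{
+//       src_desc, src_origin, dst_desc, dst_origin, element_op};
+//
+//   for(index_t k = 0; k < num_k_loops; ++k)
+//   {
+//       transfer.Run(src_desc, src_buf, dst_desc, dst_buf); // read + elementwise + write
+//       transfer.MoveSrcSliceWindow(src_desc, src_window_step);
+//       transfer.MoveDstSliceWindow(dst_desc, dst_window_step);
+//   }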
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +template +struct ThreadwiseTensorSliceTransfer_v6r1 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v6r1(const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const ElementwiseOperation& element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + // loop over space-filling curve + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + auto dst_vector_container = dst_vector_type{}; + + // apply pointwise operation + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + DstData v; + + // apply element-wise operation + element_op_(v, src_vector_container.template AsType()[i]); + + // apply type convert + dst_vector_container.template AsType()(i) = v; + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from 
dst_vector into dst_buf + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + // move coordinate + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + move_tensor_coordinate( + src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = SrcResetCoordinateAfterRun + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRun + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + SrcCoord src_coord_; + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp new file mode 100755 index 000000000000..88ed21754708 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp @@ -0,0 +1,213 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. 
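+// Note: this "r2" variant differs from threadwise_tensor_slice_transfer_v6r1.hpp in
+// where the type conversion happens. In v6r1 the elementwise operation writes its
+// result directly as DstData; here the operation is applied in SrcData precision and
+// the result is then type_convert'ed to DstData before the vector store.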
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// Assume: +// 1. src_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +template +struct ThreadwiseTensorSliceTransfer_v6r1r2 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v6r1r2( + const SrcDesc& src_desc, + const Index& src_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const ElementwiseOperation& element_op) + : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + { + src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const SrcDesc& src_desc, + const SrcBuffer& src_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + // loop over space-filling curve + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + static_for<0, num_access, 1>{}([&](auto idx_1d) { + using src_vector_type = vector_type_maker_t; + using src_vector_t = typename src_vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); + + // copy data from src_buf into src_vector_container + auto src_vector_container = src_vector_type{ + src_buf.template Get(src_coord_.GetOffset(), is_src_valid)}; + + auto dst_vector_container = dst_vector_type{}; + + // apply pointwise operation + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + SrcData v; + + // apply element-wise operation + element_op_(v, src_vector_container.template AsType()[i]); + + // apply type convert + dst_vector_container.template AsType()(i) = type_convert(v); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + 
// copy data from dst_vector into dst_buf + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + // move coordinate + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + move_tensor_coordinate( + src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move coordinate back to slice origin (or not) + if constexpr(SrcResetCoordinateAfterRun) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src_desc, src_coord_, src_reset_step); + } + + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = SrcResetCoordinateAfterRun + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx); + + move_tensor_coordinate(src_desc, src_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRun + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + SrcCoord src_coord_; + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp new file mode 100755 index 000000000000..cf2c7a2aee3d --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r2.hpp @@ -0,0 +1,260 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
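+// Note: v6r2 below is the binary variant of the v6r1 transfer. It tracks two source
+// coordinates (Src0, Src1) and one destination coordinate, and in Run() applies the
+// elementwise op per scalar as element_op_(dst, src0, src1), with the destination written
+// first. A sketch of the functor shape this call pattern assumes (names illustrative only):
+//
+//   struct AddExample
+//   {
+//       template <typename Dst, typename Src0, typename Src1>
+//       __host__ __device__ void operator()(Dst& dst, const Src0& s0, const Src1& s1) const
+//       {
+//           // accumulate in float, then convert to the destination type
+//           dst = type_convert<Dst>(type_convert<float>(s0) + type_convert<float>(s1));
+//       }
+//   };
+//
+//   transfer.Run(src0_desc, src0_buf, src1_desc, src1_buf, dst_desc, dst_buf);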
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// Assume: +// 1. src0_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +template +struct ThreadwiseTensorSliceTransfer_v6r2 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using Src0Coord = decltype(make_tensor_coordinate(Src0Desc{}, Index{})); + using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v6r2(const Src0Desc& src0_desc, + const Index& src0_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const ElementwiseOperation& element_op) + : src0_coord_(make_tensor_coordinate(src0_desc, src0_slice_origin)), + src1_coord_(make_tensor_coordinate(src1_desc, src1_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! 
cannot evenly divide"); + } + + __device__ void SetSrc0SliceOrigin(const Src0Desc& src0_desc, + const Index& src0_slice_origin_idx) + { + src0_coord_ = make_tensor_coordinate(src0_desc, src0_slice_origin_idx); + } + + __device__ void SetSrc1SliceOrigin(const Src1Desc& src1_desc, + const Index& src1_slice_origin_idx) + { + src1_coord_ = make_tensor_coordinate(src1_desc, src1_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto idx_1d) { + using src0_vector_type = vector_type_maker_t; + using src0_vector_t = typename src0_vector_type::type; + + using src1_vector_type = vector_type_maker_t; + using src1_vector_t = typename src1_vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + const bool is_src0_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src0_desc, src0_coord_); + + const bool is_src1_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src1_desc, src1_coord_); + + // copy data from src0_buf into src0_vector_container + auto src0_vector_container = src0_vector_type{ + src0_buf.template Get(src0_coord_.GetOffset(), is_src0_valid)}; + + auto src1_vector_container = src1_vector_type{ + src1_buf.template Get(src1_coord_.GetOffset(), is_src1_valid)}; + + auto dst_vector_container = dst_vector_type{}; + + // apply pointwise operation + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + element_op_(dst_vector_container.template AsType()(i), + src0_vector_container.template AsType()[i], + src1_vector_container.template AsType()[i]); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + // copy data from dst_vector into dst_buf + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + // move coordinate + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + move_tensor_coordinate( + src0_desc, src0_coord_, make_tensor_coordinate_step(src0_desc, forward_step)); + move_tensor_coordinate( + src1_desc, src1_coord_, make_tensor_coordinate_step(src1_desc, forward_step)); + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move coordinate back to slice origin (or not) + if constexpr(Src0ResetCoordinateAfterRun) + { + const auto src0_reset_step = + make_tensor_coordinate_step(src0_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src0_desc, src0_coord_, src0_reset_step); + } + + if constexpr(Src1ResetCoordinateAfterRun) + { + const auto src1_reset_step = + make_tensor_coordinate_step(src1_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src1_desc, src1_coord_, src1_reset_step); 
+ } + + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, + const Index& src0_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src0ResetCoordinateAfterRun + ? src0_slice_origin_step_idx + : src0_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src0_desc, adjusted_step_idx); + + move_tensor_coordinate(src0_desc, src0_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, + const Index& src1_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src1ResetCoordinateAfterRun + ? src1_slice_origin_step_idx + : src1_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src1_desc, adjusted_step_idx); + + move_tensor_coordinate(src1_desc, src1_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRun + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + Src0Coord src0_coord_; + Src1Coord src1_coord_; + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp new file mode 100755 index 000000000000..b5847e51b42e --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r3.hpp @@ -0,0 +1,310 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
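+// Note: v6r3 below extends v6r2 from two source tensors to three (Src0, Src1, Src2). The
+// per-scalar call becomes element_op_(dst, src0, src1, src2), and Run() takes the three
+// descriptor/buffer pairs followed by the destination, e.g. (illustrative names):
+//
+//   transfer.Run(src0_desc, src0_buf, src1_desc, src1_buf, src2_desc, src2_buf,
+//                dst_desc, dst_buf);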
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +namespace ck { + +// Do following things to avoid "alloca" in LLVM-IR, which would cause scratch memory +// and sometimes useless instructions: +// 1. Don't save a reference to tensor descriptor in class, pass in tensor descriptor as argument +// instead +// 2. Don't construct a new tensor coordinate everytime when using it, update and reuse the same +// tensor coordinate instead +// 3. Don't use a pointer to VGPR buffer, use vector instead + +// Assume: +// 1. src0_desc and dst_desc are not known at compile-time +// 2. SrcBuffer and DstBuffer are DynamicBuffer +// 3. src_slice_origin and dst_slice_origin are not known at compile-time, +template +struct ThreadwiseTensorSliceTransfer_v6r3 +{ + static constexpr index_t nDim = SliceLengths::Size(); + + using Index = MultiIndex; + + using Src0Coord = decltype(make_tensor_coordinate(Src0Desc{}, Index{})); + using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{})); + using Src2Coord = decltype(make_tensor_coordinate(Src2Desc{}, Index{})); + using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{})); + + static constexpr auto I0 = Number<0>{}; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc, + const Index& src0_slice_origin, + const Src1Desc& src1_desc, + const Index& src1_slice_origin, + const Src2Desc& src2_desc, + const Index& src2_slice_origin, + const DstDesc& dst_desc, + const Index& dst_slice_origin, + const ElementwiseOperation& element_op) + : src0_coord_(make_tensor_coordinate(src0_desc, src0_slice_origin)), + src1_coord_(make_tensor_coordinate(src1_desc, src1_slice_origin)), + src2_coord_(make_tensor_coordinate(src2_desc, src2_slice_origin)), + dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! 
cannot evenly divide"); + } + + __device__ void SetSrc0SliceOrigin(const Src0Desc& src0_desc, + const Index& src0_slice_origin_idx) + { + src0_coord_ = make_tensor_coordinate(src0_desc, src0_slice_origin_idx); + } + + __device__ void SetSrc1SliceOrigin(const Src1Desc& src1_desc, + const Index& src1_slice_origin_idx) + { + src1_coord_ = make_tensor_coordinate(src1_desc, src1_slice_origin_idx); + } + + __device__ void SetSrc2SliceOrigin(const Src2Desc& src2_desc, + const Index& src2_slice_origin_idx) + { + src2_coord_ = make_tensor_coordinate(src2_desc, src2_slice_origin_idx); + } + + __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx) + { + dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx); + } + + template + __device__ void Run(const Src0Desc& src0_desc, + const Src0Buffer& src0_buf, + const Src1Desc& src1_desc, + const Src1Buffer& src1_buf, + const Src2Desc& src2_desc, + const Src2Buffer& src2_buf, + const DstDesc& dst_desc, + DstBuffer& dst_buf) + { + // scalar per access on each dim + // TODO: don't use lambda_scalar_per_access + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto idx_1d) { + using src0_vector_type = vector_type_maker_t; + using src0_vector_t = typename src0_vector_type::type; + + using src1_vector_type = vector_type_maker_t; + using src1_vector_t = typename src1_vector_type::type; + + using src2_vector_type = vector_type_maker_t; + using src2_vector_t = typename src2_vector_type::type; + + using dst_vector_type = vector_type_maker_t; + using dst_vector_t = typename dst_vector_type::type; + + const bool is_src0_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src0_desc, src0_coord_); + + const bool is_src1_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src1_desc, src1_coord_); + + const bool is_src2_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src2_desc, src2_coord_); + + // copy data from src0_buf into src0_vector_container + auto src0_vector_container = src0_vector_type{ + src0_buf.template Get(src0_coord_.GetOffset(), is_src0_valid)}; + + auto src1_vector_container = src1_vector_type{ + src1_buf.template Get(src1_coord_.GetOffset(), is_src1_valid)}; + + auto src2_vector_container = src2_vector_type{ + src2_buf.template Get(src2_coord_.GetOffset(), is_src2_valid)}; + + auto dst_vector_container = dst_vector_type{}; + + // apply pointwise operation + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + element_op_(dst_vector_container.template AsType()(i), + src0_vector_container.template AsType()[i], + src1_vector_container.template AsType()[i], + src2_vector_container.template AsType()[i]); + }); + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); + + dst_buf.template Update( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector_container.template AsType()[I0]); + + // move coordinate + if constexpr(idx_1d.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d); + move_tensor_coordinate( + src0_desc, src0_coord_, make_tensor_coordinate_step(src0_desc, forward_step)); + move_tensor_coordinate( + src1_desc, src1_coord_, make_tensor_coordinate_step(src1_desc, forward_step)); + 
move_tensor_coordinate( + src2_desc, src2_coord_, make_tensor_coordinate_step(src2_desc, forward_step)); + move_tensor_coordinate( + dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step)); + } + }); + + // move coordinate back to slice origin (or not) + if constexpr(Src0ResetCoordinateAfterRun) + { + const auto src0_reset_step = + make_tensor_coordinate_step(src0_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src0_desc, src0_coord_, src0_reset_step); + } + + if constexpr(Src1ResetCoordinateAfterRun) + { + const auto src1_reset_step = + make_tensor_coordinate_step(src1_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src1_desc, src1_coord_, src1_reset_step); + } + + if constexpr(Src2ResetCoordinateAfterRun) + { + const auto src2_reset_step = + make_tensor_coordinate_step(src2_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(src2_desc, src2_coord_, src2_reset_step); + } + + if constexpr(DstResetCoordinateAfterRun) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep()); + + move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step); + } + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = SpaceFillingCurve>; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc, + const Index& src0_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src0ResetCoordinateAfterRun + ? src0_slice_origin_step_idx + : src0_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src0_desc, adjusted_step_idx); + + move_tensor_coordinate(src0_desc, src0_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc, + const Index& src1_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src1ResetCoordinateAfterRun + ? src1_slice_origin_step_idx + : src1_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src1_desc, adjusted_step_idx); + + move_tensor_coordinate(src1_desc, src1_coord_, adjusted_step); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc, + const Index& src2_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = Src2ResetCoordinateAfterRun + ? src2_slice_origin_step_idx + : src2_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? 
+ const auto adjusted_step = make_tensor_coordinate_step(src2_desc, adjusted_step_idx); + + move_tensor_coordinate(src2_desc, src2_coord_, adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRun + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx); + + move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step); + } + + private: + Src0Coord src0_coord_; + Src1Coord src1_coord_; + Src2Coord src2_coord_; + DstCoord dst_coord_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp new file mode 100755 index 000000000000..db7dee21992d --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7.hpp @@ -0,0 +1,298 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" + +namespace ck { + +// Thread-level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// 6. Does not need to know src_descs and dst_descs at compile-time +// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time, +// +// Does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer +// 2. Pass tensor descritpors by reference (or tuple of references) +// 3. Does not keep reference to tensor descriptor +// 4. 
Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename DimAccessOrder, + index_t VectorDim, + index_t ScalarPerVector, + typename SrcResetCoordinateAfterRunFlags, // Sequence + typename DstResetCoordinateAfterRunFlags> // Sequence +struct ThreadwiseTensorSliceTransfer_v7 +{ + static constexpr auto I0 = Number<0>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + using Index = MultiIndex; + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + // scalar per access on each dim + // FIXME: don't use lambda_scalar_per_access + static constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SpaceFillingCurve = + SpaceFillingCurve>; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v7( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % ScalarPerVector == 0, + "wrong! cannot evenly divide"); + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto i) { + src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto i) { + dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); + }); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + auto generate_vectors = [&](auto data_types) { + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); + }; + + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + // loop over space-filling curve + static_for<0, num_access, 1>{}([&](auto iAccess) { + auto src_vectors = generate_vectors(SrcDatas{}); + auto dst_vectors = generate_vectors(DstDatas{}); + + // copy data from src_bufs into src_vectors + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), + is_src_valid); + }); + + // apply pointwise function + static_for<0, ScalarPerVector, 1>{}([&](auto i) { + // get reference to src data + const 
auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + return dst_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) + unpack2(element_op_, dst_data_refs, src_data_refs); + }); + + // copy data from buf_vectors into dst_bufs + static_for<0, nDst, 1>{}([&](auto i) { + using dst_vector_t = typename remove_cvref_t::type; + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + dst_coords_[i]); + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(i.value)); + + dst_bufs(i).template Update( + dst_coords_[i].GetOffset(), + is_dst_valid, + dst_vectors[i].template AsType()[I0]); + }); + + // move coordinate + if constexpr(iAccess.value != num_access - 1) + { + constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nSrc, 1>{}([&](auto i) { + move_tensor_coordinate(src_descs[i], + src_coords_(i), + make_tensor_coordinate_step(src_descs[i], forward_step)); + }); + + static_for<0, nDst, 1>{}([&](auto i) { + move_tensor_coordinate(dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step)); + }); + } + }); + + // move coordinate back to slice origin (or not) + static_for<0, nSrc, 1>{}([&](auto i) { + if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_descs[i], GetCoordinateResetStep()); + + move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); + } + }); + + static_for<0, nDst, 1>{}([&](auto i) { + if constexpr(DstResetCoordinateAfterRunFlags::At(i)) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_descs[i], GetCoordinateResetStep()); + + move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step); + } + }); + } + + __device__ static constexpr auto GetCoordinateResetStep() + { + constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); + + if constexpr(num_access == 0) + { + return typename SpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + SpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + + return reset_step; + } + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + Number iSrc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = SrcResetCoordinateAfterRunFlags::At(iSrc) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? 
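+ // Note: when SrcResetCoordinateAfterRunFlags::At(iSrc) is false, Run() leaves
+ // src_coords_(iSrc) at the last access of the space-filling curve, so the requested
+ // window step is combined with GetCoordinateResetStep() (the step from the last access
+ // back to access 0) to land on the new slice origin in a single move.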
+ const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx); + + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + Number iDst, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = DstResetCoordinateAfterRunFlags::At(iDst) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx); + + move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); + } + + private: + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp new file mode 100755 index 000000000000..4b277e4383f2 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp @@ -0,0 +1,630 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/utility/is_detected.hpp" +#include "ck/tensor/static_tensor.hpp" + +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + +namespace ck { +// Thread-level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// 6. Does not need to know src_descs and dst_descs at compile-time +// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time, +// +// Does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer +// 2. Pass tensor descritpors by reference (or tuple of references) +// 3. Does not keep reference to tensor descriptor +// 4. 
Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + index_t SrcScalarPerVector, + index_t DstScalarPerVector, + typename SrcResetCoordinateAfterRunFlags, // Sequence + typename DstResetCoordinateAfterRunFlags, // Sequence + index_t NumThreadScratch = 1> +struct ThreadwiseTensorSliceTransfer_v7r2 +{ + static constexpr auto I0 = Number<0>{}; + + static constexpr index_t nDim = SliceLengths::Size(); + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + using Index = MultiIndex; + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + // scalar per access on each dim + // FIXME: don't use lambda_scalar_per_access + static constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + static constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SrcSpaceFillingCurve = SpaceFillingCurve, + false>; + + using DstSpaceFillingCurve = SpaceFillingCurve, + false>; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v7r2( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! cannot evenly divide"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! 
cannot evenly divide"); + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto i) { + src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto i) { + dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); + }); + } + + template + __device__ static auto generate_vectors() + { + auto data_types = DataTypes{}; + + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + template = false> + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto src_vectors = generate_vectors(); + auto elm_vectors = generate_vectors(); + + bool oob_val = true; + + // copy data from src_bufs into src_vectors + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + oob_val = oob_val & is_src_valid; + + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); + }); + + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + return 1; + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + // apply pointwise function + static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return elm_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) 
+ unpack2(element_op_, dst_data_refs, src_data_refs); + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val; + + // move coordinate + if constexpr(iAccess.value != src_num_access - 1) + { + constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nSrc, 1>{}([&](auto i) { + move_tensor_coordinate(src_descs[i], + src_coords_(i), + make_tensor_coordinate_step(src_descs[i], forward_step)); + }); + } + }); + + // move coordinate back to slice origin (or not) + static_for<0, nSrc, 1>{}([&](auto i) { + if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); + } + }); + } + +#if 1 + template + __device__ void OOBCheck(Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess]; + auto oob_val = oob_vectors_tuple_[thread_scratch_id][iAccess]; + + static_for<0, nDst, 1>{}([&](auto i) { + using elm_vector_t = typename remove_cvref_t::type; + elm_vectors(i).template AsType()(I0) = + oob_val ? elm_vectors(i).template AsType()[I0] : elm_vector_t{0}; + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + }); + } +#endif + + template + __device__ void + TransposeFromElmToDst(Number thread_scratch_id = Number{}) + { + using DstData = remove_cvref_t; + + using ElmThreadScratch = + StaticTensorTupleOfVectorBuffer; + using DstThreadScratch = + StaticTensorTupleOfVectorBuffer; + + ElmThreadScratch elm_thread_scratch_; + DstThreadScratch dst_thread_scratch_; + + elm_thread_scratch_.data_ = + bit_cast(elm_vectors_tuple_[thread_scratch_id]); + + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in 
DstVectorDim + return elm_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + static_ford{}( + [&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; }); + } + + dst_vectors_tuple_(thread_scratch_id) = bit_cast(dst_thread_scratch_.data_); + } + + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers dst_bufs, + Number thread_scratch_id = Number{}) + { + OOBCheck(thread_scratch_id); + TransposeFromElmToDst(thread_scratch_id); + + // loop over space-filling curve + static_for<0, dst_num_access, 1>{}([&](auto iAccess) { + auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; + + // copy data from buf_vectors into dst_bufs + static_for<0, nDst, 1>{}([&](auto i) { + using dst_vector_t = typename remove_cvref_t::type; + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + dst_coords_[i]); + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(i.value)); + + dst_bufs(i).template Update( + dst_coords_[i].GetOffset(), + is_dst_valid, + dst_vectors[i].template AsType()[I0]); + }); + + // move coordinate + if constexpr(iAccess.value != dst_num_access - 1) + { + constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nDst, 1>{}([&](auto i) { + move_tensor_coordinate(dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step)); + }); + } + }); + + static_for<0, nDst, 1>{}([&](auto i) { + if constexpr(DstResetCoordinateAfterRunFlags::At(i)) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step); + } + }); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + if constexpr(src_num_access == 0) + { + return typename SrcSpaceFillingCurve::Index{}; + } + else + { + return SrcSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + if constexpr(dst_num_access == 0) + { + return typename DstSpaceFillingCurve::Index{}; + } + else + { + return DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + // constexpr auto src_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + 
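+ // Note: stage 1 builds a packed descriptor over the per-dimension access counts with
+ // SrcScalarPerVector appended as a trailing length; the 2nd stage below merges that
+ // trailing length back into SrcVectorDim (all other dimensions are passed through), so
+ // the thread scratch is indexed with ordinary slice-space coordinates while remaining
+ // vector-contiguous along SrcVectorDim.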
constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + // constexpr auto dst_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + Number iSrc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRunFlags::At(iSrc) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx); + + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + Number iDst, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRunFlags::At(iDst) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? 
+ const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx); + + move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); + } + + private: + using SrcVectorsType = decltype(generate_vectors()); + using ElmVectorsType = decltype(generate_vectors()); + using DstVectorsType = decltype(generate_vectors()); + + static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess(); + static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess(); + + using ElmVectorTuple = StaticallyIndexedArray; + using DstVectorTuple = StaticallyIndexedArray; + + StaticallyIndexedArray elm_vectors_tuple_; + StaticallyIndexedArray dst_vectors_tuple_; + + using OOBVectorTuple = StaticallyIndexedArray; + StaticallyIndexedArray oob_vectors_tuple_; + + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp new file mode 100755 index 000000000000..ea074144b665 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp @@ -0,0 +1,648 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/utility/is_detected.hpp" +#include "ck/tensor/static_tensor.hpp" + +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + +namespace ck { +// Thread-level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// 6. Does not need to know src_descs and dst_descs at compile-time +// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time, +// +// Does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer +// 2. Pass tensor descritpors by reference (or tuple of references) +// 3. Does not keep reference to tensor descriptor +// 4. 
Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + typename SrcScalarPerVectors, + index_t DstScalarPerVector, + typename SrcResetCoordinateAfterRunFlags, // Sequence + typename DstResetCoordinateAfterRunFlags, // Sequence + index_t NumThreadScratch = 1> +struct ThreadwiseTensorSliceTransfer_v7r3 +{ + static constexpr auto I0 = Number<0>{}; + + static constexpr auto SrcScalarPerVector = SrcScalarPerVectors{}[I0]; + + static constexpr index_t nDim = SliceLengths::Size(); + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + using Index = MultiIndex; + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + // scalar per access on each dim + // FIXME: don't use lambda_scalar_per_access + static constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + static constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SrcSpaceFillingCurve = SpaceFillingCurve, + false>; + + using DstSpaceFillingCurve = SpaceFillingCurve, + false>; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v7r3( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! cannot evenly divide"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! 
cannot evenly divide"); + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto i) { + src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto i) { + dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); + }); + } + + template + __device__ static auto generate_vectors() + { + auto data_types = DataTypes{}; + + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + template = false> + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto src_vectors = generate_vectors(); + auto elm_vectors = generate_vectors(); + + bool oob_val = true; + + // copy data from src_bufs into src_vectors + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + oob_val = oob_val & is_src_valid; + + if constexpr(SrcScalarPerVectors{}[i] == 1) + { + auto data_types = SrcDatas{}; + using DataType = remove_cvref_t; + const auto tmp = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); + + static_for<0, SrcScalarPerVector, 1>{}( + [&](auto j) { src_vectors(i).template AsType()(j) = tmp; }); + } + else + { + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); + } + }); + + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + return 1; + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + // apply pointwise function + static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return elm_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) 
+ unpack2(element_op_, dst_data_refs, src_data_refs); + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val; + + // move coordinate + if constexpr(iAccess.value != src_num_access - 1) + { + constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nSrc, 1>{}([&](auto i) { + move_tensor_coordinate(src_descs[i], + src_coords_(i), + make_tensor_coordinate_step(src_descs[i], forward_step)); + }); + } + }); + + // move coordinate back to slice origin (or not) + static_for<0, nSrc, 1>{}([&](auto i) { + if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); + } + }); + } + +#if 1 + template + __device__ void OOBCheck(Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess]; + auto oob_val = oob_vectors_tuple_[thread_scratch_id][iAccess]; + + static_for<0, nDst, 1>{}([&](auto i) { + using elm_vector_t = typename remove_cvref_t::type; + elm_vectors(i).template AsType()(I0) = + oob_val ? elm_vectors(i).template AsType()[I0] : elm_vector_t{0}; + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + }); + } +#endif + + template + __device__ void + TransposeFromElmToDst(Number thread_scratch_id = Number{}) + { + using DstData = remove_cvref_t; + + using ElmThreadScratch = + StaticTensorTupleOfVectorBuffer; + using DstThreadScratch = + StaticTensorTupleOfVectorBuffer; + + ElmThreadScratch elm_thread_scratch_; + DstThreadScratch dst_thread_scratch_; + + elm_thread_scratch_.data_ = + bit_cast(elm_vectors_tuple_[thread_scratch_id]); + + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in 
DstVectorDim + return elm_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from + // dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + static_ford{}( + [&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; }); + } + + dst_vectors_tuple_(thread_scratch_id) = bit_cast(dst_thread_scratch_.data_); + } + + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers dst_bufs, + Number thread_scratch_id = Number{}) + { + OOBCheck(thread_scratch_id); + TransposeFromElmToDst(thread_scratch_id); + + // loop over space-filling curve + static_for<0, dst_num_access, 1>{}([&](auto iAccess) { + auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; + + // copy data from buf_vectors into dst_bufs + static_for<0, nDst, 1>{}([&](auto i) { + using dst_vector_t = typename remove_cvref_t::type; + + const bool is_dst_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], + dst_coords_[i]); + + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(i.value)); + + dst_bufs(i).template Update( + dst_coords_[i].GetOffset(), + is_dst_valid, + dst_vectors[i].template AsType()[I0]); + }); + + // move coordinate + if constexpr(iAccess.value != dst_num_access - 1) + { + constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nDst, 1>{}([&](auto i) { + move_tensor_coordinate(dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step)); + }); + } + }); + + static_for<0, nDst, 1>{}([&](auto i) { + if constexpr(DstResetCoordinateAfterRunFlags::At(i)) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step); + } + }); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + if constexpr(src_num_access == 0) + { + return typename SrcSpaceFillingCurve::Index{}; + } + else + { + return SrcSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + if constexpr(dst_num_access == 0) + { + return typename DstSpaceFillingCurve::Index{}; + } + else + { + return DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + // constexpr auto src_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, + // Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of 
transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + // constexpr auto dst_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, + // Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + Number iSrc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRunFlags::At(iSrc) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx); + + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + Number iDst, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRunFlags::At(iDst) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + // is it OK to construct a new step every time? 
+ const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx); + + move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); + } + + private: + using SrcVectorsType = decltype(generate_vectors()); + using ElmVectorsType = decltype(generate_vectors()); + using DstVectorsType = decltype(generate_vectors()); + + static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess(); + static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess(); + + using ElmVectorTuple = StaticallyIndexedArray; + using DstVectorTuple = StaticallyIndexedArray; + + StaticallyIndexedArray elm_vectors_tuple_; + StaticallyIndexedArray dst_vectors_tuple_; + + using OOBVectorTuple = StaticallyIndexedArray; + StaticallyIndexedArray oob_vectors_tuple_; + + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp new file mode 100755 index 000000000000..9b1ff3dbf8de --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -0,0 +1,679 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_description/tensor_space_filling_curve.hpp" +#include "ck/utility/is_detected.hpp" +#include "ck/tensor/static_tensor.hpp" + +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_util.hpp" + +namespace ck { +// Thread-level multi-source, multi-destination tensor slice data movement +// Assume: +// 1. All sources and destinations are DynamicBuffer +// 2. Same VectorDim and ScalerPerVector for all sources and destinations +// 3. DstInMemOps are per destination tensor +// 4. ThreadTransferSrcResetCoordinateAfterRunFlags are per source tensor +// 5. ThreadTransferDstResetCoordinateAfterRunFlags are per destination tensor +// 6. Does not need to know src_descs and dst_descs at compile-time +// 7. Does not need to know src_slice_origins and dst_slice_origins at compile-time, +// +// Does following things to avoid scratch memory issue +// 1. Use StaticallyIndexedArray or vector_type instead of C array for thread buffer +// 2. Pass tensor descritpors by reference (or tuple of references) +// 3. Does not keep reference to tensor descriptor +// 4. 
Does not construct new tensor coordinate when call Run() +template + typename SliceLengths, + typename SrcDimAccessOrder, + typename DstDimAccessOrder, + index_t SrcVectorDim, + index_t DstVectorDim, + typename SrcScalarPerVectors, + index_t DstScalarPerVector, + typename SrcResetCoordinateAfterRunFlags, // Sequence + typename DstResetCoordinateAfterRunFlags, // Sequence + typename IndexType, + index_t ScatterDim = 1, + bool OutputScatter = true, + index_t ScatterWeightIdx = 3, + index_t NumThreadScratch = 1> +struct ThreadwiseTensorSliceTransfer_v7r3_scatter +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto SrcScalarPerVector = SrcScalarPerVectors{}[I0]; + + static constexpr index_t nDim = SliceLengths::Size(); + + static constexpr index_t nSrc = SrcDescs::Size(); + static constexpr index_t nDst = DstDescs::Size(); + + using Index = MultiIndex; + static constexpr index_t scatter_num = SliceLengths{}.At(Number{}); + + // return a tuple of coordiantes for a tuple of tensor + template = false> + static constexpr auto MakeCoordinates(const Descs& descs, const Indices& indices) + { + return generate_tuple([&](auto i) { return make_tensor_coordinate(descs[i], indices[i]); }, + Number{}); + } + + using SrcCoords = decltype(MakeCoordinates(SrcDescs{}, StaticallyIndexedArray{})); + using DstCoords = decltype(MakeCoordinates(DstDescs{}, StaticallyIndexedArray{})); + + // scalar per access on each dim + // FIXME: don't use lambda_scalar_per_access + static constexpr auto src_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + static constexpr auto dst_scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access{}, Number{}); + + using SrcSpaceFillingCurve = SpaceFillingCurve, + false>; + + using DstSpaceFillingCurve = SpaceFillingCurve, + false>; + + __device__ constexpr ThreadwiseTensorSliceTransfer_v7r3_scatter( + const SrcDescs& src_descs, + const StaticallyIndexedArray& src_slice_origins, + const DstDescs& dst_descs, + const StaticallyIndexedArray& dst_slice_origins, + const ElementwiseOperation& element_op) + : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), + dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), + element_op_(element_op) + { + static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, + "wrong! cannot evenly divide"); + + static_assert(SliceLengths::At(Number{}) % DstScalarPerVector == 0, + "wrong! 
cannot evenly divide"); + } + + template = false> + __device__ void SetSrcSliceOrigins(const SrcDescs& src_descs, + const Indices& src_slice_origin_idxs) + { + static_for<0, nSrc, 1>{}([&](auto i) { + src_coords_(i) = make_tensor_coordinate(src_descs[i], src_slice_origin_idxs[i]); + }); + } + + template = false> + __device__ void SetDstSliceOrigins(const DstDescs& dst_descs, + const Indices& dst_slice_origin_idxs) + { + static_for<0, nDst, 1>{}([&](auto i) { + dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); + }); + } + + template + __device__ static auto generate_vectors() + { + auto data_types = DataTypes{}; + + constexpr index_t num = data_types.Size(); + + return generate_tuple( + [&](auto i) { + using DataType = remove_cvref_t; + + return vector_type_maker_t{}; + }, + Number{}); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + template = false> + __device__ void RunRead(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto src_vectors = generate_vectors(); + auto elm_vectors = generate_vectors(); + + bool oob_val = true; + + // copy data from src_bufs into src_vectors + static_for<0, nSrc, 1>{}([&](auto i) { + using src_vector_t = typename remove_cvref_t::type; + + const bool is_src_valid = + coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], + src_coords_[i]); + + oob_val = oob_val & is_src_valid; + src_vectors(i).template AsType()(I0) = + src_bufs[i].template Get(src_coords_[i].GetOffset(), true); + }); + + constexpr auto get_elem_op_vec_len = []() { + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack8_invocable) + return math::min(8, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack4_invocable) + return math::min(4, SrcScalarPerVector); + } + if constexpr(is_detected::value) + { + if constexpr(decltype(element_op_)::is_pack2_invocable) + return math::min(2, SrcScalarPerVector); + } + return 1; + }; + + constexpr index_t elem_op_vec_len = get_elem_op_vec_len(); + + // apply pointwise function + static_for<0, SrcScalarPerVector / elem_op_vec_len, 1>{}([&](auto i) { + // get reference to src data + const auto src_data_refs = generate_tie( + // return type should be lvalue + [&](auto iSrc) -> const auto& { + using SrcData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return src_vectors[iSrc].template AsType()[i]; + }, + Number{}); + + // get reference to dst data + auto dst_data_refs = generate_tie( + // return type should be lvalue + [&](auto iDst) -> auto& { + using DstData = remove_cvref_t>; + + using elem_op_vec_t = typename vector_type::type; + + return elm_vectors(iDst).template AsType()(i); + }, + Number{}); + + // apply pointwise function + // pointwise function signature: + // element_op_(dst_data_refs[I0], + // dst_data_refs[I1], + // ..., + // src_data_refs[I0], + // src_data_refs[I1], + // ...) 
+ unpack2(element_op_, dst_data_refs, src_data_refs); + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val; + + // move coordinate + if constexpr(iAccess.value != src_num_access - 1) + { + constexpr auto forward_step = SrcSpaceFillingCurve::GetForwardStep(iAccess); + + static_for<0, nSrc, 1>{}([&](auto i) { + move_tensor_coordinate(src_descs[i], + src_coords_(i), + make_tensor_coordinate_step(src_descs[i], forward_step)); + }); + } + }); + + // move coordinate back to slice origin (or not) + static_for<0, nSrc, 1>{}([&](auto i) { + if constexpr(SrcResetCoordinateAfterRunFlags::At(i)) + { + const auto src_reset_step = + make_tensor_coordinate_step(src_descs[i], GetSrcCoordinateResetStep()); + + move_tensor_coordinate(src_descs[i], src_coords_(i), src_reset_step); + } + }); + } + +#if 1 + template + __device__ void OOBCheck(Number thread_scratch_id = Number{}) + { + // loop over space-filling curve + static_for<0, src_num_access, 1>{}([&](auto iAccess) { + auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess]; + auto oob_val = oob_vectors_tuple_[thread_scratch_id][iAccess]; + + static_for<0, nDst, 1>{}([&](auto i) { + using elm_vector_t = typename remove_cvref_t::type; + elm_vectors(i).template AsType()(I0) = + oob_val ? elm_vectors(i).template AsType()[I0] : elm_vector_t{0}; + }); + + elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors; + }); + } +#endif + + template + __device__ void + TransposeFromElmToDst(Number thread_scratch_id = Number{}) + { + using DstData = remove_cvref_t; + + using ElmThreadScratch = + StaticTensorTupleOfVectorBuffer; + using DstThreadScratch = + StaticTensorTupleOfVectorBuffer; + + ElmThreadScratch elm_thread_scratch_; + DstThreadScratch dst_thread_scratch_; + + elm_thread_scratch_.data_ = + bit_cast(elm_vectors_tuple_[thread_scratch_id]); + + if constexpr(SrcVectorDim != DstVectorDim && + ((is_same>::value && + SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) || + (is_same>::value && + SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) + { + // each transpose does + // DstScalarPerVector # of src vectors in src_thread_scratch_ + // SrcScalarPerVector # of dst vectors in dst_thread_scratch_ + constexpr index_t num_src_vector = Number{}; + constexpr index_t num_dst_vector = Number{}; + + // Assume SrcVectorDim is not the same as DstVectorDim, so we do transpose + // TODO: make this logic generic for all scenario + + constexpr auto src_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto dst_scalar_step_in_vector = generate_sequence( + detail::lambda_scalar_step_in_vector{}, Number{}); + + constexpr auto scalar_per_access = generate_sequence( + detail::lambda_scalar_per_access_for_src_and_dst{}, + Number{}); + + constexpr auto access_lengths = SliceLengths{} / scalar_per_access; + + static_ford{}([&](auto access_idx) { + constexpr auto data_idx = access_idx * scalar_per_access; + + constexpr auto data_idx_seq = generate_sequence_v2( + [&](auto i) { return Number{}; }, Number{}); + + using src_vector_t = vector_type_maker_t; + using dst_vector_t = vector_type_maker_t; + + // get DstScalarPerVector # of read-only references to src vectors from + // src_thread_scratch_ + const auto src_vector_refs = generate_tie( + [&](auto i) -> const src_vector_t& { + // i increment corresponds to movement in 
DstVectorDim + return elm_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * dst_scalar_step_in_vector); + }, + Number{}); + + // get SrcScalarPerVector # of references to dst vectors from + // dst_thread_scratch_ + auto dst_vector_refs = generate_tie( + [&](auto i) -> dst_vector_t& { + // i increment corresponds to movement in SrcVectorDim + return dst_thread_scratch_.GetVectorTypeReference( + data_idx_seq + i * src_scalar_step_in_vector); + }, + Number{}); + + // do data transpose + transpose_vectors{}( + src_vector_refs, dst_vector_refs); + }); + } + else + { + static_ford{}( + [&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; }); + } + + dst_vectors_tuple_(thread_scratch_id) = bit_cast(dst_thread_scratch_.data_); + } + + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void RunWrite(const DstDescs& dst_descs, + DstBuffers dst_bufs, + StaticallyIndexedArray& scatter_offsets, + Number thread_scratch_id = Number{}) + { + OOBCheck(thread_scratch_id); + TransposeFromElmToDst(thread_scratch_id); + + // loop over space-filling curve + static_for<0, dst_num_access, 1>{}([&](auto iAccess) { + auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess]; + IndexType scatter_offset = 0; + if constexpr(OutputScatter) + { + constexpr auto iScatter = + DstSpaceFillingCurve::GetIndex(iAccess)(Number{}); + scatter_offset = scatter_offsets(Number{}); + } + // copy data from buf_vectors into dst_bufs + static_for<0, nDst, 1>{}([&](auto i) { + using dst_vector_t = typename remove_cvref_t::type; + IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset()); + const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize(); + constexpr InMemoryDataOperationEnum DstInMemOp = + static_cast(DstInMemOps::At(i.value)); + dst_bufs(i).template Update( + dst_offset, is_dst_valid, dst_vectors[i].template AsType()[I0]); + }); + + // move coordinate + if constexpr(iAccess.value != dst_num_access - 1) + { + constexpr auto forward_step = DstSpaceFillingCurve::GetForwardStep(iAccess); + + auto forward_step_scatter = [&]() constexpr + { + Index step_; + + static_for<0, nDim, 1>{}([&](auto i) { + step_(i) = (i.value == ScatterDim && OutputScatter) ? 
0 : forward_step[i]; + }); + + return step_; + } + (); + static_for<0, nDst, 1>{}([&](auto i) { + move_tensor_coordinate( + dst_descs[i], + dst_coords_(i), + make_tensor_coordinate_step(dst_descs[i], forward_step_scatter)); + }); + } + }); + + static_for<0, nDst, 1>{}([&](auto i) { + if constexpr(DstResetCoordinateAfterRunFlags::At(i)) + { + const auto dst_reset_step = + make_tensor_coordinate_step(dst_descs[i], GetDstCoordinateResetStep()); + + move_tensor_coordinate(dst_descs[i], dst_coords_(i), dst_reset_step); + } + }); + } + + // SrcDescs: Tuple + // SrcBuffers: Tuple + // DstDescs: Tuple + // DstBuffers: Tuple + template = false> + __device__ void Run(const SrcDescs& src_descs, + const SrcBuffers& src_bufs, + const DstDescs& dst_descs, + DstBuffers dst_bufs, + StaticallyIndexedArray& scatter_offsets) + { + RunRead(src_descs, src_bufs); + RunWrite(dst_descs, dst_bufs, scatter_offsets); + } + + __device__ static constexpr auto GetSrcCoordinateResetStep() + { + if constexpr(src_num_access == 0) + { + return typename SrcSpaceFillingCurve::Index{}; + } + else + { + return SrcSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + } + } + + __device__ static constexpr auto GetDstCoordinateResetStep() + { + if constexpr(dst_num_access == 0) + { + return typename DstSpaceFillingCurve::Index{}; + } + else + { + constexpr auto reset_step = + DstSpaceFillingCurve::GetStepBetween(Number{}, Number<0>{}); + auto reset_step_scatter = [&]() constexpr + { + Index step_; + static_for<0, nDim, 1>{}([&](auto i) { + step_(i) = + (i.value == ScatterDim && OutputScatter) ? 0 : reset_step[Number{}]; + }); + + return step_; + } + (); + return reset_step_scatter; + } + } + + __device__ static constexpr auto GetSrcThreadScratchDescriptor() + { + // constexpr auto src_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, + // Number{}); + + constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access; + + constexpr auto src_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(src_access_lengths), Number{}); + + // 1st stage of transforms + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(src_access_lengths_and_vector_length); + + // 2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(src_access_lengths_and_vector_length[i], + src_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(src_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == SrcVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + __device__ static constexpr auto GetDstThreadScratchDescriptor() + { + // 1st stage of transforms + // constexpr auto dst_scalar_per_access = generate_sequence( + // detail::lambda_scalar_per_access{}, + // Number{}); + + constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access; + + constexpr auto dst_access_lengths_and_vector_length = container_push_back( + sequence_to_tuple_of_number(dst_access_lengths), Number{}); + + constexpr auto desc0 = + make_naive_tensor_descriptor_packed(dst_access_lengths_and_vector_length); + + // 
2nd stage of transforms + constexpr auto transforms = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return make_merge_transform_v3_division_mod( + make_tuple(dst_access_lengths_and_vector_length[i], + dst_access_lengths_and_vector_length[Number{}])); + } + else + { + return make_pass_through_transform(dst_access_lengths_and_vector_length[i]); + } + }, + Number{}); + + constexpr auto low_dim_idss = generate_tuple( + [&](auto i) { + if constexpr(i == DstVectorDim) + { + return Sequence{}; + } + else + { + return Sequence{}; + } + }, + Number{}); + + constexpr auto up_dim_idss = + generate_tuple([&](auto i) { return Sequence{}; }, Number{}); + + return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss); + } + + // src_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, + Number iSrc, + const Index& src_slice_origin_step_idx) + { + // if src coord was not reset by RunRead(), then need to adjust the step here + const auto adjusted_step_idx = + SrcResetCoordinateAfterRunFlags::At(iSrc) + ? src_slice_origin_step_idx + : src_slice_origin_step_idx + GetSrcCoordinateResetStep(); + + // is it OK to construct a new step every time? + const auto adjusted_step = make_tensor_coordinate_step(src_descs[iSrc], adjusted_step_idx); + + move_tensor_coordinate(src_descs[iSrc], src_coords_(iSrc), adjusted_step); + } + + // dst_slice_origin_step_idx need to be known at compile-time, for performance reason + template + __device__ void MoveDstSliceWindow(const DstDescs& dst_descs, + Number iDst, + const Index& dst_slice_origin_step_idx) + { + // if dst coord was not reset by Run(), then need to adjust the step here + const auto adjusted_step_idx = + DstResetCoordinateAfterRunFlags::At(iDst) + ? dst_slice_origin_step_idx + : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + + auto adjusted_step_idx_scatter = [&]() { + Index step_; + static_for<0, nDim, 1>{}([&](auto i) { + step_(i) = + (i.value == ScatterDim && OutputScatter) ? 0 : adjusted_step_idx[Number{}]; + }); + + return step_; + }(); + // is it OK to construct a new step every time? 
+ const auto adjusted_step = + make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx_scatter); + + move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); + } + + private: + using SrcVectorsType = decltype(generate_vectors()); + using ElmVectorsType = decltype(generate_vectors()); + using DstVectorsType = decltype(generate_vectors()); + + static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess(); + static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess(); + + using ElmVectorTuple = StaticallyIndexedArray; + using DstVectorTuple = StaticallyIndexedArray; + + StaticallyIndexedArray elm_vectors_tuple_; + StaticallyIndexedArray dst_vectors_tuple_; + + using OOBVectorTuple = StaticallyIndexedArray; + StaticallyIndexedArray oob_vectors_tuple_; + + SrcCoords src_coords_; + DstCoords dst_coords_; + const ElementwiseOperation element_op_; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp new file mode 100755 index 000000000000..eb6715e8ebb3 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/thread/threadwise_welford.hpp @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/math_v2.hpp" + +namespace ck { + +// Assume +// 1) XDesc is known at compile-time +// 2) MeanVarDesc is known at compile-time +// 3) XBuffer is static buffer +// 4) MeanBuffer is static buffer +// 5) VarBuffer is static buffer +template +struct ThreadwiseWelford +{ + static constexpr auto x_thread_desc_m_k = XThreadDesc_M_K{}; + static constexpr auto mean_var_thread_desc_m = MeanVarThreadDesc_M{}; + + static constexpr auto thread_x_length_m = x_thread_desc_m_k.GetLength(Number<0>{}); + static constexpr auto thread_x_length_k = x_thread_desc_m_k.GetLength(Number<1>{}); + static constexpr auto thread_mean_var_length_m = mean_var_thread_desc_m.GetLength(Number<0>{}); + + static_assert(thread_x_length_m == thread_mean_var_length_m, + "lengths of source and mean/var buffer must match!"); + + __device__ constexpr ThreadwiseWelford() : cur_count_(0), max_count_(0) {} + + __device__ inline void Update(T& mean, T& var, T x) + { + using ck::math::isnan; + + if(isnan(x)) + { + mean = x; + var = x; + } + else + { + T delta = x - mean; + mean += delta / cur_count_; + T delta2 = x - mean; + var += delta * delta2; + } + } + + template + __device__ void + Run(const XBufferType& x_buf_m_k, MeanBufferType& mean_buf_m, VarBufferType& var_buf_m) + { + // FIXME - Better naming for var_buf_m + + static_for<0, thread_x_length_k, 1>{}([&](auto iK) { + if(cur_count_ < max_count_) + { + ++cur_count_; + + static_for<0, thread_x_length_m, 1>{}([&](auto iM) { + constexpr index_t out_offset = + mean_var_thread_desc_m.CalculateOffset(make_tuple(iM)); + + constexpr auto in_offset = + x_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + Update(mean_buf_m(Number{}), + var_buf_m(Number{}), + x_buf_m_k[Number{}]); + }); + } + }); + }; + + int cur_count_; + int max_count_; +}; + +template +struct ThreadwiseWelfordMerge +{ + static constexpr auto src_thread_desc_m_k = SrcMeanVarCountThreadDesc_M_K{}; + static constexpr auto dst_thread_desc_m = DstMeanVarThreadDesc_M{}; + + static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{}); + static constexpr auto src_length_k = 
src_thread_desc_m_k.GetLength(Number<1>{}); + static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{}); + + static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!"); + + __device__ static void + Merge(T& mean_a, T& var_a, int32_t& count_a, T mean_b, T var_b, int32_t count_b) + { + int count = count_a + count_b; + T count_b_over_count = count == 0 ? type_convert(0) : type_convert(count_b) / count; + T delta = mean_b - mean_a; + mean_a += delta * count_b_over_count; + var_a += var_b + delta * delta * count_a * count_b_over_count; + count_a = count; + } + + template + __device__ static void Run(const SrcMeanBufferType& src_mean_buf, + const SrcVarBufferType& src_var_buf, + const SrcCountBufferType& src_count_buf, + DstMeanBufferType& dst_mean_buf, + DstVarBufferType& dst_var_buf, + DstCountBufferType& dst_count_buf) + { + static_for<0, src_length_m, 1>{}([&](auto iM) { + static_for<0, src_length_k, 1>{}([&](auto iK) { + constexpr auto src_offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK)); + + Merge(dst_mean_buf(iM), + dst_var_buf(iM), + dst_count_buf(iM), + src_mean_buf[Number{}], + src_var_buf[Number{}], + src_count_buf[Number{}]); + }); + + if constexpr(GetActualVariance) + { + dst_var_buf(iM) = dst_var_buf[iM] / dst_count_buf[iM]; + }; + }); + }; +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp new file mode 100755 index 000000000000..409bb9f67465 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp @@ -0,0 +1,538 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/amd_gemm_dpp.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +enum struct DppInstr +{ + dpp8_f16_1x32x2 = 0, + dpp8_f16_2x16x2, + dpp8_f16_2x32x2, + dpp8_f16_4x16x2, + dpp8_f16_4x32x2, + dpp8_f16_8x16x2, + dpp8_f16_8x32x2, + dpp8_f16_16x16x2, + dpp8_f16_32x8x2 +}; + +/** + * Structure representing DPP GEMM executed by a single wavefront. + * + * Each structure instantiation must contain the following fields: + * - wave_size - number of threads that execute a single DPP GEMM operation, usually equal to the + * number of threads in a wavefront; + * - lanegroup_size - number of threads (lanes) that share data using DPP instruction modifier, + * it's 8 in case of DPP8; + * - m_per_wave - size along M dimension of matrix C that is processed in a single DPP GEMM + * operation; + * - n_per_wave - size along N dimension of matrix C that is processed in a single DPP GEMM + * operation; + * - m_per_lanegroup - size along M dimension that is processed by a single lanegroup; + * - n_per_lanegroup - size along N dimension that is processed by a single lanegroup; + * - m_per_thread - size along M dimension of the tile calculated by a single thread; + * - n_per_thread - size along N dimension of the tile calculated by a single thread; + * - k_per_dpp - size along K dimension that is reduced in a single DPP GEMM operation; + * - share_a - indicates whether we share matrix A or matrix B between lanes using DPP modifiers. + * + * Not all the combinarions are supported now, for current restrictions see the static asserts + * in the DppSelector's contructor. 
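 *
 * A worked example using the constants of the dpp8_f16_32x8x2 specialization
 * below: a 32-lane wave splits into wave_size / lanegroup_size = 4 lanegroups
 * of 8 lanes. Each lanegroup computes an m_per_lanegroup x n_per_lanegroup =
 * 8x8 tile of C, and the 4 tiles are stacked along M to cover the full
 * m_per_wave x n_per_wave = 32x8 wave tile. Within a lanegroup every lane
 * accumulates an m_per_thread x n_per_thread = 8x1 column; since share_a is
 * true, the A fragments are exchanged between the 8 lanes of the lanegroup
 * with DPP8 modifiers while each lane keeps its own column of B. Every
 * operation reduces k_per_dpp = 2 elements along K.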
+ */ +template +struct dpp_type; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 32; + static constexpr index_t n_per_wave = 8; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 8; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 8; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 4; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 4; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 16; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 8; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 8; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 4; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 4; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 4; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 4; + static constexpr 
index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 2; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 2; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 1; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 1; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 1; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 2; + static constexpr index_t n_per_wave = 32; + static constexpr index_t m_per_lanegroup = 2; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 2; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template <> +struct dpp_type +{ + static constexpr index_t wave_size = 32; + static constexpr index_t lanegroup_size = 8; + static constexpr index_t m_per_wave = 2; + static constexpr index_t n_per_wave = 16; + static constexpr index_t m_per_lanegroup = 1; + static constexpr index_t n_per_lanegroup = 8; + static constexpr index_t m_per_thread = 1; + static constexpr index_t n_per_thread = 1; + static constexpr index_t k_per_dpp = 2; + static constexpr bool share_a = true; + using BaseType = half_t; + + template + __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const + { + dpp8::DppLanegroupGemm{} + .Run(a, b, reg_c); + } +}; + +template +struct DppSelector +{ + template + static constexpr auto GetDpp(); + + template <> + constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_8x32x2; + } + + template <> + constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_8x16x2; + } + + template <> + constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_16x16x2; + } + + template <> + constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_32x8x2; + } + + template <> + constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_1x32x2; + } + + template <> + constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_2x32x2; + } + + template <> + constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_2x16x2; + } + + template <> + constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_4x16x2; + } + + template <> + constexpr auto GetDpp() + { + return DppInstr::dpp8_f16_4x32x2; + } + + static constexpr auto selected_dpp = dpp_type()>{}; + + __host__ __device__ constexpr DppSelector() + { + static_assert(selected_dpp.m_per_wave % selected_dpp.m_per_lanegroup == 0); + 
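        // Together with the m_per_wave check above, the divisibility checks below
        // guarantee that the wave-level C tile (m_per_wave x n_per_wave) splits into
        // exactly one m_per_lanegroup x n_per_lanegroup tile per lanegroup, i.e.
        // wave_size / lanegroup_size DPP GEMMs per wave.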
static_assert(selected_dpp.n_per_wave % selected_dpp.n_per_lanegroup == 0); + + static_assert(selected_dpp.k_per_dpp % 2 == 0); + + static_assert(selected_dpp.wave_size % selected_dpp.lanegroup_size == 0); + constexpr index_t num_dpp_per_wave = selected_dpp.wave_size / selected_dpp.lanegroup_size; + constexpr index_t num_wave_c_elems = selected_dpp.m_per_wave * selected_dpp.n_per_wave; + constexpr index_t num_dpp_c_elems = + selected_dpp.m_per_lanegroup * selected_dpp.n_per_lanegroup; + static_assert(num_wave_c_elems % num_dpp_c_elems == 0); + static_assert(num_dpp_per_wave == num_wave_c_elems / num_dpp_c_elems); + + if constexpr(selected_dpp.share_a) + { + static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread); + static_assert(selected_dpp.n_per_lanegroup % selected_dpp.n_per_thread == 0); + static_assert(selected_dpp.n_per_lanegroup / selected_dpp.n_per_thread == + selected_dpp.lanegroup_size); + } + else + { + static_assert(selected_dpp.m_per_lanegroup % selected_dpp.n_per_thread == 0); + static_assert(selected_dpp.m_per_lanegroup / selected_dpp.n_per_thread == + selected_dpp.lanegroup_size); + static_assert(selected_dpp.n_per_lanegroup == selected_dpp.n_per_thread); + } + + // Below checks come from the restrictions of the current implementation, could be removed + // in the future when the implementation is more generalized. + static_assert(selected_dpp.share_a); + static_assert(selected_dpp.n_per_thread == 1); + static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread); + static_assert(selected_dpp.n_per_lanegroup == + selected_dpp.n_per_thread * selected_dpp.lanegroup_size); + } + + static constexpr index_t GetK1PerDpp() { return selected_dpp.k_per_dpp; } +}; + +template +struct DppGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using CIndex = MultiIndex<2>; + using CIndex4D = MultiIndex<4>; + + __host__ __device__ constexpr DppGemm() + { + static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp."); + } + + __device__ static constexpr index_t GetRegSizePerDpp() + { + return MPerDpp * NPerDpp / dpp_instr.wave_size; + } + + template + __device__ void + Run(const ADataType& p_a_wave, const BDataType& p_b_wave, CDataType& p_c_thread) const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "base BaseType must be double, float, half, bfloat16, and int8_t!"); + + static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) { + dpp_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); + }); + } + + __device__ static auto GetLaneIdInWave() + { + return get_thread_local_1d_id() % dpp_instr.wave_size; + } + + __device__ static auto GetWaveId() { return get_thread_local_1d_id() / dpp_instr.wave_size; } + + __device__ static auto GetLaneIdInLaneGroup() + { + return get_thread_local_1d_id() % dpp_instr.lanegroup_size; + } + + __device__ static auto GetLaneGroupIdInWave() + { + return GetLaneIdInWave() / dpp_instr.lanegroup_size; + } + + __device__ static auto GetDppOpIdx() + { + const auto lanegroupId = GetLaneGroupIdInWave(); + + constexpr auto lanegroup_idx_1d_to_dpp_idx_2d_adaptor = make_single_stage_tensor_adaptor( + make_tuple( + make_merge_transform(make_tuple(dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup, + dpp_instr.n_per_wave / 
dpp_instr.n_per_lanegroup))), + make_tuple(Sequence<0, 1>{}), + make_tuple(Sequence<0>{})); + + const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex( + make_multi_index(lanegroupId)); + + const auto m_dpp_idx = dpp_idx[I0]; + const auto n_dpp_idx = dpp_idx[I1]; + + return make_tuple(m_dpp_idx, n_dpp_idx); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex_K_M() + { + const auto laneId = get_thread_local_1d_id(); + const auto wave_row = laneId / dpp_instr.n_per_wave; + auto m_idx = dpp_instr.m_per_thread * wave_row + GetLaneIdInLaneGroup(); + return make_tuple(0, m_idx % dpp_instr.m_per_wave); + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex_K_N() + { + const auto laneId = get_thread_local_1d_id(); + return make_tuple(0, laneId % dpp_instr.n_per_wave); + } + + __device__ static CIndex GetBeginOfThreadBlk() + { + const auto dpp_op_idx = GetDppOpIdx(); + + const auto m_dpp_op_idx = dpp_op_idx[I0]; + const auto n_dpp_op_idx = dpp_op_idx[I1]; + + index_t n_offset = n_dpp_op_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup(); + index_t m_offset = m_dpp_op_idx * dpp_instr.m_per_lanegroup; + + return CIndex{m_offset, n_offset}; + } + + static constexpr auto dpp = DppSelector{}; + + static constexpr auto dpp_instr = dpp.selected_dpp; + + static constexpr auto K0PerDpp = 1; + static constexpr auto K1PerDpp = dpp.GetK1PerDpp(); + + __host__ __device__ static constexpr auto GetCMNThreadBlkLengths() + { + return make_tuple(Number{}, Number{}); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp new file mode 100755 index 000000000000..a436afd3951a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/warp/smfmac_xdlops_gemm.hpp @@ -0,0 +1,433 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
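//
// Warp-level GEMM wrappers around the sparse MFMA ("smfmac") instructions.
// They mirror the dense xdlops path, but each instruction takes an extra index
// operand describing which elements of the structured-sparse A fragment are
// present (the `idx` argument threaded through run()/Run() below); this is how
// a 16x16x32 or 32x32x16 instruction covers twice the K depth of its dense
// 16x16x16 / 32x32x8 counterpart.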
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/amd_smfmac.hpp" + +namespace ck { + +enum struct SmfmacInstr +{ + smfmac_f32_16x16x32f16 = 0, + smfmac_f32_32x32x16f16, + smfmac_f32_16x16x32bf16, + smfmac_f32_32x32x16bf16, +}; + +template +struct smfmac_type; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_16x16x32f16::Run( + a, b, idx, reg_c); + } +}; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_32x32x16f16::Run( + a, b, idx, reg_c); + } +}; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_16x16x32bf16::Run( + a, b, idx, reg_c); + } +}; + +template <> +struct smfmac +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, const index_t& idx, FloatC& reg_c) const + { + intrin_smfmac_f32_32x32x16bf16::Run( + a, b, idx, reg_c); + } +}; + +template +struct SmfmacSelector +{ + template + static constexpr auto GetSmfmac(); + + template <> + static constexpr auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_16x16x32f16; + } + + template <> + static constexpr auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_32x32x16f16; + } + + template <> + static constexpr 
auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_16x16x32bf16; + } + + template <> + static constexpr auto GetSmfmac() + { + return SmfmacInstr::smfmac_f32_32x32x16bf16; + } + + static constexpr auto selected_smfmac = + smfmac_type()>{}; + + __host__ __device__ constexpr SmfmacSelector() + { + static_assert(selected_smfmac.group_size * selected_smfmac.num_groups_per_blk == + selected_smfmac.num_regs_per_blk, + "wrong! num_regs_per_blk"); + + static_assert(selected_smfmac.num_threads_per_blk == selected_smfmac.n_per_blk, + "n_per_blk != num_threads_per_blk"); + + static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.num_input_blks == + selected_smfmac.m_per_blk, + "m_per_blk != num_input_blks * num_regs_per_blk"); + + static_assert(selected_smfmac.num_output_blks == selected_smfmac.num_input_blks || + selected_smfmac.num_output_blks == 1, + "incorrect num_output_blks"); + + static_assert(selected_smfmac.num_regs_per_blk * selected_smfmac.wave_size == + selected_smfmac.m_per_blk * selected_smfmac.n_per_blk, + "num_regs_per_blk incorrect"); + + static_assert(selected_smfmac.is_k_reduction || + (selected_smfmac.num_input_blks == selected_smfmac.num_output_blks), + "is_k_reduction wrong!"); + } + + static constexpr index_t GetKPerXdlops() + { + return (selected_smfmac.is_k_reduction ? selected_smfmac.num_input_blks : 1) * + selected_smfmac.k_per_blk; + } + + static constexpr index_t GetK1PerXdlops() { return selected_smfmac.k_per_blk; } +}; + +template +struct SparseXdlopsGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using CIndex = MultiIndex<2>; + using CIndex4D = MultiIndex<4>; + + __device__ static constexpr index_t GetNumBlks() { return smfmac_instr.num_output_blks; } + + __device__ static constexpr index_t GetNumXdlops() + { + return MPerXdlops * NPerXdlops / + (smfmac_instr.m_per_blk * smfmac_instr.n_per_blk * smfmac_instr.num_output_blks); + } + + __host__ __device__ constexpr SparseXdlopsGemm() + { + static_assert(NPerXdlops == 16 || NPerXdlops == 32, + "Only support GemmNPerXdlops == 16 or 32 for smfmac xdlops"); + + static_assert(MPerXdlops == 16 || MPerXdlops == 32, + "Only support GemmMPerXdlops == 16 or 32 for smfmac xdlops"); + + static_assert(KPack % smfmac_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk"); + } + + // XDL output supporting C = A * B + // M2_N2 -> M2_M3_M4_N2 + template + __host__ __device__ static constexpr auto + MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) + { + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + + return transform_tensor_descriptor( + c_desc_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4, 5, 6>{}, + Sequence<7>{})); + } + + template + 
__host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2) + { + const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3); + const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4); + + return transform_tensor_descriptor( + c_desc_g_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(G), + make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(smfmac_instr.num_groups_per_blk, + smfmac_instr.num_input_blks, + smfmac_instr.group_size)), + make_pass_through_transform(smfmac_instr.num_threads_per_blk)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6, 7>{}, + Sequence<8>{})); + } + + __device__ static constexpr index_t GetRegSizePerXdlops() + { + return MPerXdlops * NPerXdlops / smfmac_instr.wave_size; + } + + __device__ static constexpr index_t GetWaveSize() { return smfmac_instr.wave_size; } + + template + __device__ void + Run(const FloatA& p_a_wave, const FloatB& p_b_wave, const Idx& idx, FloatC& p_c_thread) const + { + static_assert(is_same::value || is_same::value, + "base base_type must be half or bfloat16!"); + + static_for<0, KPack / smfmac_instr.k_per_blk, 1>{}([&](auto k) { + smfmac_instr.template run( + p_a_wave[k], p_b_wave[k], idx[k / 4], p_c_thread); + }); + } + + __device__ static auto GetLaneId() { return get_thread_local_1d_id() % smfmac_instr.wave_size; } + + __device__ static auto GetBlkIdx() + { + const auto laneId = GetLaneId(); + + constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform( + make_tuple(1, smfmac_instr.num_input_blks, smfmac_instr.num_threads_per_blk))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto blk_idx = + threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId)); + + const auto blk_id = blk_idx[I1]; + const auto blk_td = blk_idx[I2]; + + return make_tuple(blk_id, blk_td); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(smfmac_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(smfmac_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + + __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) + { + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + index_t n_offset = blk_i * smfmac_instr.n_per_blk + blk_td; + index_t m_offset = xdlops_i * smfmac_instr.m_per_blk + blk_id * smfmac_instr.group_size; + + return CIndex{m_offset, n_offset}; + } 
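+
+    // Illustrative example (assuming the smfmac_f32_16x16x32f16 / _bf16 configuration
+    // above: group_size = 4, num_input_blks = 4, num_threads_per_blk = 16):
+    // GetBlkIdx() splits a lane id into blk_id = laneId / num_threads_per_blk and
+    // blk_td = laneId % num_threads_per_blk, so lane 37 gives blk_id = 2, blk_td = 5.
+    // For xdlops_i = 0 and blk_i = 0, GetBeginOfThreadBlk then returns
+    // m_offset = blk_id * group_size = 8 and n_offset = blk_td = 5.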
+ + __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */) + { + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + return CIndex4D{I0, blk_id, I0, blk_td}; + } + + static constexpr auto smfmac = + SmfmacSelector{}; + + static constexpr auto smfmac_instr = smfmac.selected_smfmac; + + static constexpr auto KPerXdlops = smfmac.GetKPerXdlops(); + static constexpr auto K1PerXdlops = smfmac.GetK1PerXdlops(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + + __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() + { + return make_tuple( + Number{}, I1, Number{}, I1); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp new file mode 100755 index 000000000000..429df2413fce --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp @@ -0,0 +1,874 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/amd_wmma.hpp" + +namespace ck { + +enum struct WmmaInstr +{ + // gfx11 + wmma_f32_16x16x16_f16 = 0, + wmma_f32_16x16x16_bf16, + wmma_f16_16x16x16_f16, + wmma_bf16_16x16x16_bf16, + wmma_i32_16x16x16_iu8, + wmma_i32_16x16x16_iu4, + // gfx12 + wmma_f32_16x16x16_f16_gfx12, + wmma_f32_16x16x16_bf16_gfx12, + wmma_i32_16x16x16_iu8_gfx12, + wmma_f32_16x16x16_f8f8_gfx12, + wmma_f32_16x16x16_f8bf8_gfx12, + wmma_f32_16x16x16_bf8f8_gfx12, + wmma_f32_16x16x16_bf8bf8_gfx12, +}; + +/* + * WMMA Wave Tile Always MxNxK = 16x16x16 + * WAVE32 + ----------------------------------- + |RC0| | | | | | | | | | | | | | | | SubGroup 0 + |RC1| | | | | | | | | | | | | | | | + |RC2| | | | | | | | | | | | | | | | + |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| + |RC4|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1| + |RC5|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5| + |RC6| | | | | | | | | | | | | | | | + |RC7| | | | | | | | | | | | | | | | + ----------------------------------- + | | | | | | | | | | | | | | | | | SubGroup 1 + | | | | | | | | | | | | | | | | | + | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| + | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3| + | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1| + | | | | | | | | | | | | | | | | | + | | | | | | | | | | | | | | | | | + | | | | | | | | | | | | | | | | | + ----------------------------------- + + + * WAVE64 + ----------------------------------- + |RC0|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 0 + |RC1|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1| + |RC2|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5| + |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| + ----------------------------------- + | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 1 + | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3| + | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1| + | | | | | | | | | | | | | | | | | + ----------------------------------- + | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 2 + | 3 |3|3|3|3|3|3|3|4|4|4|4|4|4|4|4| + | 2 |3|4|5|6|7|8|9|0|1|2|3|4|5|6|7| + | | | | | | | | | | | | | | | | | + ----------------------------------- + | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 3 + | 4 |4|5|5|5|5|5|5|5|5|5|5|6|6|6|6| + | 8 |9|0|1|2|3|4|5|6|7|8|9|0|1|2|3| + | | | | | | | | | | | | | | | | | + ----------------------------------- + +* RC = Register for storing accumalted result +* T = Thread ID +*/ + +template +struct wmma_type +{ +}; + +// A-swizzled +template +struct wmma_type> 
+{ + // Absolute fixing property + // * Data Pixel + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t src_a_data_size = 2; + static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + // * Fixed on gfx11, Will be wave mode dependent for future architectures + static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + // * num_acc_vgprs_per_wave alone M direction + // * num_subgroups alone M direction + static constexpr index_t num_acc_vgprs_per_wave = + m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + if constexpr(wave_size == 32) + { + intrin_wmma_f32_16x16x16_f16_w32::Run(a, b, reg_c); + } + else if constexpr(wave_size == 64) + { + intrin_wmma_f32_16x16x16_f16_w64::Run(a, b, reg_c); + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t src_a_data_size = 2; + static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = + m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + if constexpr(wave_size == 32) + { + intrin_wmma_f32_16x16x16_bf16_w32::Run(a, b, reg_c); + } + else if constexpr(wave_size == 64) + { + intrin_wmma_f32_16x16x16_bf16_w64::Run(a, b, reg_c); + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t src_a_data_size = 2; + static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 2; + static constexpr index_t acc_pack_number = 2; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = + m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4; + static constexpr index_t num_subgroups = 
wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + if constexpr(wave_size == 32) + { + intrin_wmma_f16_16x16x16_f16_w32::Run(a, b, reg_c); + } + else if constexpr(wave_size == 64) + { + intrin_wmma_f16_16x16x16_f16_w64::Run(a, b, reg_c); + } + } +}; +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t src_a_data_size = 2; + static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 2; + static constexpr index_t acc_pack_number = 2; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = + m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + if constexpr(wave_size == 32) + { + intrin_wmma_bf16_16x16x16_bf16_w32::Run(a, b, reg_c); + } + else if constexpr(wave_size == 64) + { + intrin_wmma_bf16_16x16x16_bf16_w64::Run(a, b, reg_c); + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t src_a_data_size = 2; + static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = + m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + if constexpr(wave_size == 32) + { + intrin_wmma_i32_16x16x16_iu8_w32::Run( + a, b, reg_c); + } + else if constexpr(wave_size == 64) + { + intrin_wmma_i32_16x16x16_iu8_w64::Run( + a, b, reg_c); + } + } +}; + +// gfx12 + +// A-swizzled +template +struct wmma_type> +{ + // Absolute fixing property + // * Data Pixel + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + // static constexpr index_t src_a_data_size = 2; + // static constexpr index_t src_b_data_size = 2; + // static constexpr index_t acc_data_size = 4; + // * Thread mapping inside wave, num_thread_per_subgroups always alone N direction + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + // * Fixed for gfx11, Will be wave mode dependent on 
gfx12 + // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4; + // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4; + // * num_acc_vgprs_per_wave alone M direction + // * num_subgroups alone M direction + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { + intrin_wmma_f32_16x16x16_f16_w32_gfx12::Run(a, b, reg_c); + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + // static constexpr index_t src_a_data_size = 2; + // static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { + intrin_wmma_f32_16x16x16_bf16_w32_gfx12::Run(a, b, reg_c); + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + // static constexpr index_t src_a_data_size = 2; + // static constexpr index_t src_b_data_size = 2; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4; + // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { + intrin_wmma_i32_16x16x16_iu8_w32_gfx12::Run( + a, b, reg_c); + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t 
num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { +#ifdef __gfx12__ + intrin_wmma_f32_16x16x16_f8f8_w32_gfx12::Run(a, b, reg_c); +#else + ignore = a; + ignore = b; + ignore = reg_c; +#endif + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { +#ifdef __gfx12__ + intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12::Run(a, b, reg_c); +#else + ignore = a; + ignore = b; + ignore = reg_c; +#endif + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { +#ifdef __gfx12__ + intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12::Run(a, b, reg_c); +#else + ignore = a; + ignore = b; + ignore = reg_c; +#endif + } + } +}; + +template +struct wmma_type> +{ + // Absolute fixing property + static constexpr index_t m_per_wmma = 16; + static constexpr index_t n_per_wmma = 16; + static constexpr index_t k_per_wmma = 16; + static constexpr index_t acc_data_size = 4; + static constexpr index_t acc_pack_number = 1; + static constexpr index_t num_thread_per_subgroups = n_per_wmma; + + // Wave mode dependent propety + static constexpr index_t wave_size = Number{}; + static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size; + static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + static_assert(wave_size == 32, "only support wave32 for gfx12 wmma"); + if constexpr(wave_size == 32) + { +#ifdef __gfx12__ + intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12::Run(a, b, reg_c); +#else + ignore = a; + ignore = b; + ignore = reg_c; +#endif + } + } +}; + +template +struct WmmaSelector +{ + template + static constexpr auto GetWmma(); + + template <> + constexpr auto GetWmma() + { +#ifdef __gfx12__ + return 
WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
+#else
+        return WmmaInstr::wmma_f32_16x16x16_f16;
+#endif
+    }
+
+    template <>
+    constexpr auto GetWmma()
+    {
+#ifdef __gfx12__
+        return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
+#else
+        return WmmaInstr::wmma_f32_16x16x16_bf16;
+#endif
+    }
+
+    template <>
+    constexpr auto GetWmma()
+    {
+        return WmmaInstr::wmma_f16_16x16x16_f16;
+    }
+
+    template <>
+    constexpr auto GetWmma()
+    {
+        return WmmaInstr::wmma_bf16_16x16x16_bf16;
+    }
+
+    template <>
+    constexpr auto GetWmma()
+    {
+#ifdef __gfx12__
+        return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
+#else
+        return WmmaInstr::wmma_i32_16x16x16_iu8;
+#endif
+    }
+
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+    template <>
+    constexpr auto GetWmma()
+    {
+        return WmmaInstr::wmma_i32_16x16x16_iu4;
+    }
+#endif
+
+    template <>
+    constexpr auto GetWmma()
+    {
+        return WmmaInstr::wmma_f32_16x16x16_f8f8_gfx12;
+    }
+
+    template <>
+    constexpr auto GetWmma()
+    {
+        return WmmaInstr::wmma_f32_16x16x16_f8bf8_gfx12;
+    }
+
+    template <>
+    constexpr auto GetWmma()
+    {
+        return WmmaInstr::wmma_f32_16x16x16_bf8f8_gfx12;
+    }
+
+    template <>
+    constexpr auto GetWmma()
+    {
+        return WmmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12;
+    }
+
+    // get_warp_size does not return the correct wave size; hardcode to 32 as a workaround
+    static constexpr auto selected_wmma =
+        wmma_type(), Number<32>{}>{};
+
+    __host__ __device__ constexpr WmmaSelector()
+    {
+        static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal 16");
+
+        static_assert(selected_wmma.n_per_wmma == 16, "WRONG! WMMA_N must equal 16");
+
+        static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_K must equal 16");
+
+        static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
+                              selected_wmma.acc_data_size * selected_wmma.acc_pack_number ==
+                          selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
+                      "WRONG!
Invalid Number of Accumulator Register"); + } +}; + +template +struct WmmaGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using CIndex = MultiIndex<2>; + using CIndex3D = MultiIndex<3>; + + __host__ __device__ constexpr WmmaGemm() + { + static_assert(NPerWmma == 16 && MPerWmma == 16, + "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma"); + + static_assert(KPack % wmma_instr.k_per_wmma == 0, "KPack should be multiple of k_per_wmma"); + } + + // WMMA output supporting C = A * B + // Vector Write + // MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave + template + __host__ __device__ static constexpr auto + MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs( + const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma) + { + const auto MBlockxRepeat = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0); + const auto NBlockxRepeat = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3); + const auto MWave = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1); + const auto NWave = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4); + + return transform_tensor_descriptor( + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma, + make_tuple( + make_pass_through_transform(MBlockxRepeat), + make_pass_through_transform(MWave), + make_unmerge_transform(make_tuple(Number{}, + Number{})), + make_pass_through_transform(NBlockxRepeat), + make_pass_through_transform(NWave), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2, 6>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + } + + // Transposed WMMA Output C' = B' * A' + template + __host__ __device__ static constexpr auto + MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs( + const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA& + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma) + { + const auto MBlockxRepeat = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0); + const auto NBlockxRepeat = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3); + const auto MWave = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1); + const auto NWave = + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4); + + return transform_tensor_descriptor( + c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma, + make_tuple( + make_pass_through_transform(MBlockxRepeat), + make_pass_through_transform(MWave), + make_pass_through_transform(Number{}), + make_pass_through_transform(NBlockxRepeat), + make_pass_through_transform(NWave), + make_unmerge_transform(make_tuple(Number{}, + Number{}))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6>{})); + } + + __device__ 
static constexpr index_t GetRegSizePerWmma() + { + return wmma_instr.num_acc_vgprs_per_wave * wmma_instr.acc_pack_number; + } + + __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; } + + template + __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const + { + static_assert( + (is_same::value && is_same::value && + is_same::value) || + (is_same::value && is_same::value && + is_same::value) || + (is_same::value && is_same::value && + is_same::value) || + (is_same::value && is_same::value && + is_same::value) || + (is_same::value && is_same::value && + is_same::value) || + ((is_same::value || is_same::value) && + (is_same::value || is_same::value) && + is_same::value) || +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 + (is_same::value && is_same::value && + is_same::value) || +#endif + false, + "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), " + "((f8 or bf8, f8 or bf8), float), (int8, int32) or (int4, int32)!"); + static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) { + if constexpr(!TransposeC) + { + wmma_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); + } + else + { + wmma_instr.template run(p_b_wave[k], p_a_wave[k], p_c_thread); + } + }); + } + + __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; } + + __device__ static auto GetSubGroupId() + { + static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups == + wmma_instr.wave_size, + ""); + return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups; + } + + __device__ static auto GetLaneIdUnderSubGroup() + { + return GetLaneId() % wmma_instr.num_thread_per_subgroups; + } + __device__ static auto GetSwizzledLaneIdLow() + { + return ((GetLaneIdUnderSubGroup() & 1) << 3) | (GetLaneIdUnderSubGroup() >> 1); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex() + { +#ifdef __gfx12__ + return GetLaneIdUnderSubGroup(); +#else + return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow(); +#endif + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex() + { +#ifdef __gfx12__ + return GetLaneIdUnderSubGroup(); +#else + return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup(); +#endif + } + + __device__ static CIndex GetBeginOfThreadBlk() + { + index_t n_offset = GetLaneIdUnderSubGroup(); + index_t m_offset = GetSubGroupId() * wmma_instr.num_acc_vgprs_per_wave; + + return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset}; + } + + __device__ static CIndex3D GetBeginOfThreadBlk3D() + { + index_t n_offset = GetLaneIdUnderSubGroup(); + index_t m_offset = GetSubGroupId(); + + return TransposeC ? 
CIndex3D{n_offset, m_offset, I0} : CIndex3D{m_offset, n_offset, I0}; + } + + static constexpr auto wmma = + WmmaSelector{}; + static constexpr auto wmma_instr = wmma.selected_wmma; + + __host__ __device__ static constexpr auto + GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths() + { + return make_tuple(I1, + I1, + Number{}, + Number{}); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp new file mode 100755 index 000000000000..64d7f92750cc --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -0,0 +1,1752 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/amd_xdlops.hpp" + +namespace ck { +/** + * @brief Define matrix data types that have hardware support for MX GEMMs + */ +template +static constexpr bool is_scale_mfma_data_type() +{ + using U = element_type_t; + return is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v; +} + +/** + * @brief Define scale data types that have hardware support for MX GEMMs + */ +template +static constexpr bool is_scale_mfma_scale_type() +{ + return is_same_v; +} + +/** + * @brief Combination of data types that have hardware support for MX GEMMs + */ +template +static constexpr bool scale_mfma_hw_support() +{ + return is_scale_mfma_data_type() && is_scale_mfma_data_type() && + is_scale_mfma_scale_type() && is_scale_mfma_scale_type(); +} + +enum struct MfmaInstr +{ + mfma_f32_32x32x1xf32 = 0, + mfma_f32_16x16x1xf32, + mfma_f32_4x4x1xf32, + mfma_f32_32x32x2xf32, + mfma_f32_16x16x4xf32, + mfma_f32_32x32x4f16, + mfma_f32_16x16x4f16, + mfma_f32_4x4x4f16, + mfma_f32_32x32x8f16, + mfma_f32_16x16x16f16, + mfma_f32_32x32x8bf16_1k, + mfma_f32_16x16x16bf16_1k, + mfma_f32_32x32x4bf16, + mfma_f32_16x16x8bf16, + mfma_i32_32x32x8i8, + mfma_i32_16x16x16i8, + mfma_i32_32x32x16i8, + mfma_i32_16x16x32i8, + mfma_f64_16x16x4f64, + mfma_f32_32x32x16f8f8, + mfma_f32_16x16x32f8f8, + mfma_f32_32x32x16bf8bf8, + mfma_f32_16x16x32bf8bf8, + mfma_f32_32x32x16f8bf8, + mfma_f32_16x16x32f8bf8, + mfma_f32_32x32x16bf8f8, + mfma_f32_16x16x32bf8f8, + mfma_f32_32x32x16f16, + mfma_f32_16x16x32f16, + mfma_f32_32x32x16bf16, + mfma_f32_16x16x32bf16, + mfma_i32_32x32x32i8, + mfma_i32_16x16x64i8, + mfma_f32_32x32x64f8f6f4, + mfma_f32_16x16x128f8f6f4, + mfma_scale_f32_32x32x64f8f6f4, + mfma_scale_f32_16x16x128f8f6f4 +}; + +template +struct mfma_type; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 2; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = false; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x1f32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static 
constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x2f32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x4f32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 4; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = false; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x1f32::Run(a, b, reg_c); + } +}; + +// treat 4x4x1 as a single-blk 4x64 mfma +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 4; + static constexpr index_t n_per_blk = 64; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = false; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_4x4x1f32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 2; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = false; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x4f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t 
num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x8f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x16f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 4; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = false; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x4f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 64; + static constexpr index_t wave_size = 64; + 
static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 4; + static constexpr index_t n_per_blk = 64; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = false; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_4x4x4f16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16bf16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x8bf16_1k::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32bf16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x16bf16_1k::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr 
index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x4bf16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x8bf16::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_32x32x8i8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_16x16x16i8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_32x32x16i8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + 
static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_16x16x32i8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_32x32x32i8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 16; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_i32_16x16x64i8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 1; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 4; // group_size * num_groups_per_blk; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; // wave_size / num_threads_per_blk; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f64_16x16x4f64::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16f8f8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + 
static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32f8f8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16bf8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32bf8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16f8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32f8bf8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr 
index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x16bf8f8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 8; + static constexpr bool is_k_reduction = true; + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x32bf8f8::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 32; // n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 32; // from the instruction + static constexpr index_t n_per_blk = 32; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks + static constexpr bool is_k_reduction = true; // ??? + // clang-format on + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x64f8f6f4::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 16; // == n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 16; // from the instruction + static constexpr index_t n_per_blk = 16; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks + static constexpr bool is_k_reduction = true; // ??? + // clang-format on + + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x128f8f6f4::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 4; // ??? 
group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 32; // n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 32; // from the instruction + static constexpr index_t n_per_blk = 32; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks + static constexpr bool is_k_reduction = true; // ??? + // clang-format on + + template + __device__ void run(const FloatA& a, + const ScaleA& scale_a, + const FloatB& b, + const ScaleB& scale_b, + FloatC& reg_c) const + { + intrin_mfma_scale_f32_32x32x64f8f6f4::Run( + a, bit_cast(scale_a), b, bit_cast(scale_b), reg_c); + } +}; + +template <> +struct mfma_type +{ + // clang-format off + static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk + static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size + static constexpr index_t num_threads_per_blk = 16; // == n_per_blk + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk + static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? + static constexpr index_t m_per_blk = 16; // from the instruction + static constexpr index_t n_per_blk = 16; // from the instruction + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks + static constexpr bool is_k_reduction = true; // ??? 
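+    // Worked check of the "???" relations for this 16x16 scaled specialization:
+    //   group_size * num_groups_per_blk   = 4 * 1        = 4  = num_regs_per_blk
+    //   m_per_blk * n_per_blk / wave_size = 16 * 16 / 64 = 4  = num_regs_per_blk
+    //   num_threads_per_blk = n_per_blk = 16
+    //   with is_k_reduction: KPerXdlops = num_input_blks * k_per_blk = 4 * 32 = 128,
+    //   matching the x128 in mfma_scale_f32_16x16x128f8f6f4 (see
+    //   MfmaSelector::GetKPerXdlops below).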
+ // clang-format on + + template + __device__ void run(const FloatA& a, + const ScaleA& scale_a, + const FloatB& b, + const ScaleB& scale_b, + FloatC& reg_c) const + { + + intrin_mfma_scale_f32_16x16x128f8f6f4::Run( + a, bit_cast(scale_a), b, bit_cast(scale_b), reg_c); + } +}; + +template +struct MfmaSelector +{ + template + static constexpr auto GetMfma(); + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f64_16x16x4f64; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x1xf32; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x1xf32; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x1xf32; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x1xf32; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x1xf32; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x2xf32; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x4xf32; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x4f16; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x4f16; + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x16f16; +#else + return MfmaInstr::mfma_f32_32x32x8f16; +#endif + } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x8f16; + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x32f16; +#else + return MfmaInstr::mfma_f32_16x16x16f16; +#endif + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x16f16; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x4f16; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x4f16; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x4f16; + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x16bf16; +#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP) + return MfmaInstr::mfma_f32_32x32x8bf16_1k; +#else + return MfmaInstr::mfma_f32_32x32x4bf16; +#endif + } + + template <> + constexpr auto GetMfma() + { +#if defined(CK_USE_AMD_MFMA_BF16_1K_OP) + return MfmaInstr::mfma_f32_32x32x8bf16_1k; +#else + return MfmaInstr::mfma_f32_32x32x4bf16; +#endif + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x32bf16; +#elif defined(CK_USE_AMD_MFMA_BF16_1K_OP) + return MfmaInstr::mfma_f32_16x16x16bf16_1k; +#else + return MfmaInstr::mfma_f32_16x16x8bf16; +#endif + } + + template <> + constexpr auto GetMfma() + { +#if defined(CK_USE_AMD_MFMA_BF16_1K_OP) + return MfmaInstr::mfma_f32_16x16x16bf16_1k; +#else + return MfmaInstr::mfma_f32_16x16x8bf16; +#endif + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_i32_32x32x32i8; +#elif defined(__gfx942__) + return MfmaInstr::mfma_i32_32x32x16i8; +#else + return MfmaInstr::mfma_i32_32x32x8i8; +#endif + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx942__) || defined(__gfx950__) + return MfmaInstr::mfma_i32_32x32x16i8; +#else + return MfmaInstr::mfma_i32_32x32x8i8; +#endif + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_i32_16x16x64i8; 
+#elif defined(__gfx942__) + return MfmaInstr::mfma_i32_16x16x32i8; +#else + return MfmaInstr::mfma_i32_16x16x16i8; +#endif + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx942__) || defined(__gfx950__) + return MfmaInstr::mfma_i32_16x16x32i8; +#else + return MfmaInstr::mfma_i32_16x16x16i8; +#endif + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16f8f8; + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x64f8f6f4; +#else + return MfmaInstr::mfma_f32_32x32x16f8f8; +#endif + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; + } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; + } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32f8f8; + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x128f8f6f4; +#else + return MfmaInstr::mfma_f32_16x16x32f8f8; +#endif + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; + } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; + } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16bf8bf8; + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x64f8f6f4; +#else + return MfmaInstr::mfma_f32_32x32x16bf8bf8; +#endif + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32bf8bf8; + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x128f8f6f4; +#else + return MfmaInstr::mfma_f32_16x16x32bf8bf8; +#endif + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16f8bf8; + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x64f8f6f4; +#else + return MfmaInstr::mfma_f32_32x32x16f8bf8; +#endif + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x32f8bf8; + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x128f8f6f4; +#else + return MfmaInstr::mfma_f32_16x16x32f8bf8; +#endif + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x16bf8f8; + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_32x32x64f8f6f4; +#else + return MfmaInstr::mfma_f32_32x32x16bf8f8; +#endif + } + + template <> + constexpr auto GetMfma() + { + return 
MfmaInstr::mfma_f32_16x16x32bf8f8; + } + + template <> + constexpr auto GetMfma() + { +#if defined(__gfx950__) + return MfmaInstr::mfma_f32_16x16x128f8f6f4; +#else + return MfmaInstr::mfma_f32_16x16x32bf8f8; +#endif + } + + static constexpr auto selected_mfma = mfma_type, + MPerXdlops, + NPerXdlops, + element_type_t, + is_single_rate_mfma, + is_scale_mfma>()>{}; + + __host__ __device__ constexpr MfmaSelector() + { + static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk == + selected_mfma.num_regs_per_blk, + "wrong! num_regs_per_blk"); + + static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk, + "n_per_blk != num_threads_per_blk"); + + static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks == + selected_mfma.m_per_blk, + "m_per_blk != num_input_blks * num_regs_per_blk"); + + static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks || + selected_mfma.num_output_blks == 1, + "incorrect num_output_blks"); + + static_assert(selected_mfma.num_regs_per_blk * selected_mfma.wave_size == + selected_mfma.m_per_blk * selected_mfma.n_per_blk, + "num_regs_per_blk incorrect"); + + static_assert(selected_mfma.is_k_reduction || + (selected_mfma.num_input_blks == selected_mfma.num_output_blks), + "is_k_reduction wrong!"); + } + + static constexpr bool IsABroadcast() + { + static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast"); + return true; + } + + static constexpr index_t GetKPerXdlops() + { + return (selected_mfma.is_k_reduction ? selected_mfma.num_input_blks : 1) * + selected_mfma.k_per_blk; + } + + static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; } +}; + +template +struct XdlopsGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + using CIndex = MultiIndex<2>; + using CIndex4D = MultiIndex<4>; + + __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; } + + __device__ static constexpr index_t GetNumXdlops() + { + return MPerXdlops * NPerXdlops / + (mfma_instr.m_per_blk * mfma_instr.n_per_blk * mfma_instr.num_output_blks); + } + + __host__ __device__ constexpr XdlopsGemm() + { + static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 || + NPerXdlops == 64, + "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); + + static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 || + MPerXdlops == 64, + "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); + + static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk"); + } + + // XDL output supporting C = A * B + // M2_N2 -> M2_M3_M4_N2 + template + __host__ __device__ static constexpr auto + MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) + { + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + + return transform_tensor_descriptor( + c_desc_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(Number{}, + 
Number{}, + Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4, 5, 6>{}, + Sequence<7>{})); + } + + // XDL output supporting C = A * B + // M3_N3 -> M3_M4_M5_N3 + template + __host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3( + const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) + { + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4); + const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5); + + return transform_tensor_descriptor( + c_desc_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_pass_through_transform(M2), + make_pass_through_transform(N2), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6, 7, 8>{}, + Sequence<9>{})); + } + + // transposed XDL output supporting C' = B' * A' + // M2_N2 -> M2_N2_N3_N4 + template + __host__ __device__ static constexpr auto + MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2) + { + const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3); + + return transform_tensor_descriptor( + c_desc_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, + Number{}, + Number{}))), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6, 7>{})); + } + + template + __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2) + { + const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0); + const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1); + const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2); + const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3); + const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4); + + return transform_tensor_descriptor( + c_desc_g_m0_n0_m1_n1_m2_n2, + make_tuple(make_pass_through_transform(G), + make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk, + mfma_instr.num_input_blks, + mfma_instr.group_size)), + make_pass_through_transform(mfma_instr.num_threads_per_blk)), + 
make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5, 6, 7>{}, + Sequence<8>{})); + } + + __device__ static constexpr index_t GetRegSizePerXdlops() + { + return MPerXdlops * NPerXdlops / mfma_instr.wave_size; + } + + __device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; } + + template + __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const + { + static_assert( + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || + (is_same::value && is_same::value) || + (is_same::value && is_same::value), + "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"); + + static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { + if constexpr(!TransposeC) + { + mfma_instr.template run( + p_a_wave[k], p_b_wave[k], p_c_thread); + } + else + { + mfma_instr.template run( + p_b_wave[k], p_a_wave[k], p_c_thread); + } + }); + } + + template + __device__ void Run(const FloatA& p_a_wave, + const ScaleA& a_scale_thread, + const FloatB& p_b_wave, + const ScaleB& b_scale_thread, + FloatC& p_c_thread) const + { + static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { + if constexpr(!TransposeC) + { + mfma_instr.template run( + p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread); + } + else + { + mfma_instr.template run( + p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread); + } + }); + } + + __device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; } + + __device__ static auto GetBlkIdx() + { + const auto laneId = GetLaneId(); + + constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform( + make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto blk_idx = + threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId)); + + const auto blk_id = blk_idx[I1]; + const auto blk_td = blk_idx[I2]; + + return make_tuple(blk_id, blk_td); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(mfma_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(mfma_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + + __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) + { + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + index_t n_offset = blk_i * mfma_instr.n_per_blk + blk_td; + index_t m_offset = xdlops_i * mfma_instr.m_per_blk + blk_id * mfma_instr.group_size; + + return TransposeC ? 
CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset}; + } + + __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */) + { + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td}; + } + + // Falls back to single rate instruction on gfx950 if KPack is single rate; no change on gfx942- + // when base_type is either f8_t or bf8_t, additional_type will always be either f8_t or bf8_t, + // except Use single rate mfma instruction for this special case A (f8_t) * B (pk_i4_t) + static constexpr bool is_single_rate_mfma = + (((is_same::value || is_same::value) && + KPack <= 4) || + (is_same::value && KPack <= 8) || + ((is_same::value || is_same::value) && KPack < 32) || + is_same::value) + ? true + : false; + static constexpr auto mfma = MfmaSelector{}; + + static constexpr auto mfma_instr = mfma.selected_mfma; + + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + + __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() + { + return make_tuple( + Number{}, I1, Number{}, I1); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp new file mode 100755 index 000000000000..ea27a40ce3c7 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm.hpp @@ -0,0 +1,288 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" + +namespace ck { +namespace tensor_operation { + +// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] +template +static auto MakeGridDescriptorPair(const std::vector& gs_ms_ns_lengths_vec, + const std::vector& gs_ms_ns_strides_vec) +{ + if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN)) + { + throw std::runtime_error("wrong! dimension must match input lengths"); + } + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto gs_ms_ns_lengths = + to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number{}); + const auto gs_ms_ns_strides = + to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number{}); + + // dimension Ids for G0, G1, ... + constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{}; + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = + typename arithmetic_sequence_gen::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for G0, G1, ... + const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds); + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds); + + // lengths for N0, N1, ... 
+ const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds); + + if constexpr(TensorSpec == device::TensorSpecialization::Packed) + { + auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{}); + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(G, M, N), + make_tuple(gs_ms_ns_strides[Number{}], + gs_ms_ns_strides[Number{}], + gs_ms_ns_strides[Number{}])); + + const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(M, N), + make_tuple(gs_ms_ns_strides[Number{}], + gs_ms_ns_strides[Number{}])); + + return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw); + } + else + { + // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + const auto grid_desc_gs_ms_ns = + make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides); + + // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] + // Note: This does not require padding as it only provides G offset calculation. Technically + // descriptor for only G is needed. Here we opt for backward compatibility purpose to return + // G_M_N + const auto grid_desc_g_mraw_nraw = + transform_tensor_descriptor(grid_desc_gs_ms_ns, + make_tuple(make_merge_transform(gLengths), + make_merge_transform(mLengths), + make_merge_transform(nLengths)), + make_tuple(gDimIds, mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto c_ms_ns_lengths = to_tuple( + gs_ms_ns_lengths_vec, Number{}, Number{}); + const auto c_ms_ns_strides = to_tuple( + gs_ms_ns_strides_vec, Number{}, Number{}); + + // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] 
+ const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides); + + const auto grid_desc_mraw_nraw = transform_tensor_descriptor( + grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds - Number{}, nDimIds - Number{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw); + } +} + +template + typename PerBlock_M_N_K_O, // Sequence<> + device::GemmSpecialization GemmSpec, + device::TensorSpecialization ASpec, + device::TensorSpecialization B0Spec, + device::TensorSpecialization B1Spec, + device::TensorSpecialization CSpec> +struct TransformBatchedContractionContractionToBatchedGemmGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0); + static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1); + static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2); + static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3); + static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4); + + static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0); + static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1); + static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2); + static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3); + + static constexpr auto matrix_padder = + device::GemmGemmPadder{ + MPerBlock, NPerBlock, KPerBlock, OPerBlock}; + + // + // A + // + static auto MakeAGridDescriptorPair(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + return MakeGridDescriptorPair(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec); + } + + // TODO: rename to G_MRaw_KRaw + static auto MakeAGridDescriptor_G_M_K(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first; + } + static auto MakeAGridDescriptor_M_K(const std::vector& a_gs_ms_ks_lengths_vec, + const std::vector& a_gs_ms_ks_strides_vec) + { + return matrix_padder.PadADescriptor_M_K( + MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // + // B (alias of B0) + // + static auto MakeB0GridDescriptorPair(const std::vector& b0_gs_ns_ks_lengths_vec, + const std::vector& b0_gs_ns_ks_strides_vec) + { + return MakeGridDescriptorPair(b0_gs_ns_ks_lengths_vec, + b0_gs_ns_ks_strides_vec); + } + + // TODO: rename to G_MRaw_NRaw + static auto MakeB0GridDescriptor_G_N_K(const std::vector& b0_gs_ns_ks_lengths_vec, + const std::vector& b0_gs_ns_ks_strides_vec) + { + return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first; + } + static auto MakeB0GridDescriptor_N_K(const std::vector& b0_gs_ns_ks_lengths_vec, + const 
std::vector& b0_gs_ns_ks_strides_vec) + { + // alias of matrix_padder.PadB0Descriptor_N_K + return matrix_padder.PadBDescriptor_N_K( + MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // + // B1 + // + static auto MakeB1GridDescriptorPair(const std::vector& b1_gs_os_ns_lengths_vec, + const std::vector& b1_gs_os_ns_strides_vec) + { + return MakeGridDescriptorPair(b1_gs_os_ns_lengths_vec, + b1_gs_os_ns_strides_vec); + } + + // TODO: rename to G_NRaw_KRaw + static auto MakeB1GridDescriptor_G_N_K(const std::vector& b1_gs_os_ns_lengths_vec, + const std::vector& b1_gs_os_ns_strides_vec) + { + return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first; + } + static auto MakeB1GridDescriptor_N_K(const std::vector& b1_gs_os_ns_lengths_vec, + const std::vector& b1_gs_os_ns_strides_vec) + { + // alias of matrix_padder.PadB1Descriptor_O_N + return matrix_padder.PadB1Descriptor_N_K( + MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1) + { + const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + // + // C + // + static auto MakeCGridDescriptorPair(const std::vector& c_gs_ms_os_lengths_vec, + const std::vector& c_gs_ms_os_strides_vec) + { + return MakeGridDescriptorPair(c_gs_ms_os_lengths_vec, + c_gs_ms_os_strides_vec); + } + + // TODO: rename to G_MRaw_NRaw + static auto MakeCGridDescriptor_G_M_N(const std::vector& c_gs_ms_os_lengths_vec, + const std::vector& c_gs_ms_os_strides_vec) + { + return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first; + } + static auto MakeCGridDescriptor_M_N(const std::vector& c_gs_ms_os_lengths_vec, + const std::vector& c_gs_ms_os_strides_vec) + { + return matrix_padder.PadCDescriptor_M_N( + MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second); + } +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp new file mode 100755 index 000000000000..56181d38c873 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_contraction_to_gemm_arraybase.hpp @@ -0,0 +1,391 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
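+//
+// Array-based variant of transform_contraction_to_gemm.hpp: the descriptor builders
+// below take std::array lengths/strides instead of std::vector and are marked
+// __host__ __device__, so they can also be evaluated in device code (the runtime
+// size check is therefore commented out rather than throwing). Illustrative example
+// (hypothetical sizes): with NumDimG = 1, NumDimM = 2, NumDimN = 2 and Packed
+// tensors, lengths {G0, M0, M1, N0, N1} collapse to G = G0, M = M0 * M1,
+// N = N0 * N1 in MakeGridDescriptorPair.
+//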
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp" + +namespace ck { +namespace tensor_operation { + +// assume C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] +template +__host__ __device__ static auto +MakeGridDescriptorPair(const std::array& gs_ms_ns_lengths_vec, + const std::array& gs_ms_ns_strides_vec) +{ + // if(!(gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN && + // gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN)) + // { + // throw std::runtime_error("wrong! dimension must match input lengths"); + // } + + const auto to_tuple = [&](auto& vec, auto start, auto end) { + return generate_tuple([&](auto i) { return vec[start + i]; }, Number{}); + }; + + const auto gs_ms_ns_lengths = + to_tuple(gs_ms_ns_lengths_vec, Number<0>{}, Number{}); + const auto gs_ms_ns_strides = + to_tuple(gs_ms_ns_strides_vec, Number<0>{}, Number{}); + + // dimension Ids for G0, G1, ... + constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{}; + + // dimension Ids for M0, M1, ... + constexpr auto mDimIds = + typename arithmetic_sequence_gen::type{}; + + // dimension Ids for N0, N1, ... + constexpr auto nDimIds = + typename arithmetic_sequence_gen::type{}; + + // lengths for G0, G1, ... + const auto gLengths = get_container_subset(gs_ms_ns_lengths, gDimIds); + + // lengths for M0, M1, ... + const auto mLengths = get_container_subset(gs_ms_ns_lengths, mDimIds); + + // lengths for N0, N1, ... + const auto nLengths = get_container_subset(gs_ms_ns_lengths, nDimIds); + + if constexpr(TensorSpec == device::TensorSpecialization::Packed) + { + auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{}); + auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{}); + auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{}); + const auto grid_desc_g_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(G, M, N), + make_tuple(gs_ms_ns_strides[Number{}], + gs_ms_ns_strides[Number{}], + gs_ms_ns_strides[Number{}])); + + const auto grid_desc_mraw_nraw = make_naive_tensor_descriptor( + make_tuple(M, N), + make_tuple(gs_ms_ns_strides[Number{}], + gs_ms_ns_strides[Number{}])); + + return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw); + } + else + { + // naive tensor C[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...] + const auto grid_desc_gs_ms_ns = + make_naive_tensor_descriptor(gs_ms_ns_lengths, gs_ms_ns_strides); + + // transformed tensor C[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] + // Note: This does not require padding as it only provides G offset calculation. Technically + // descriptor for only G is needed. Here we opt for backward compatibility purpose to return + // G_M_N + const auto grid_desc_g_mraw_nraw = + transform_tensor_descriptor(grid_desc_gs_ms_ns, + make_tuple(make_merge_transform(gLengths), + make_merge_transform(mLengths), + make_merge_transform(nLengths)), + make_tuple(gDimIds, mDimIds, nDimIds), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto c_ms_ns_lengths = to_tuple( + gs_ms_ns_lengths_vec, Number{}, Number{}); + const auto c_ms_ns_strides = to_tuple( + gs_ms_ns_strides_vec, Number{}, Number{}); + + // transformed tensor C[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * + // N2 * ...] 
+ const auto grid_desc_ms_ns = make_naive_tensor_descriptor(c_ms_ns_lengths, c_ms_ns_strides); + + const auto grid_desc_mraw_nraw = transform_tensor_descriptor( + grid_desc_ms_ns, + make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)), + make_tuple(mDimIds - Number{}, nDimIds - Number{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return std::make_pair(grid_desc_g_mraw_nraw, grid_desc_mraw_nraw); + } +} + +template + typename PerBlock_M_N_K_O, // Sequence<> + device::GemmSpecialization GemmSpec, + device::TensorSpecialization ASpec, + device::TensorSpecialization B0Spec, + device::TensorSpecialization B1Spec, + device::TensorSpecialization CSpec> +struct TransformBatchedContractionContractionToBatchedGemmGemm_Wmma +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + + static constexpr index_t NumDimG = NumDims_G_M_N_K_O::At(I0); + static constexpr index_t NumDimM = NumDims_G_M_N_K_O::At(I1); + static constexpr index_t NumDimN = NumDims_G_M_N_K_O::At(I2); + static constexpr index_t NumDimK = NumDims_G_M_N_K_O::At(I3); + static constexpr index_t NumDimO = NumDims_G_M_N_K_O::At(I4); + + static constexpr index_t MPerBlock = PerBlock_M_N_K_O::At(I0); + static constexpr index_t NPerBlock = PerBlock_M_N_K_O::At(I1); + static constexpr index_t KPerBlock = PerBlock_M_N_K_O::At(I2); + static constexpr index_t OPerBlock = PerBlock_M_N_K_O::At(I3); + + static constexpr auto matrix_padder = + device::GemmGemmPadder{ + MPerBlock, NPerBlock, KPerBlock, OPerBlock}; + + // + // A + // + __host__ __device__ static auto MakeAGridDescriptorPair( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + return MakeGridDescriptorPair(a_gs_ms_ks_lengths_vec, + a_gs_ms_ks_strides_vec); + } + + // TODO: rename to G_MRaw_KRaw + __host__ __device__ static auto MakeAGridDescriptor_G_M_K( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + return MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).first; + } + __host__ __device__ static auto MakeAGridDescriptor_M_K( + const std::array& a_gs_ms_ks_lengths_vec, + const std::array& a_gs_ms_ks_strides_vec) + { + return matrix_padder.PadADescriptor_M_K( + MakeAGridDescriptorPair(a_gs_ms_ks_lengths_vec, a_gs_ms_ks_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k, const Number& AK1) + { + const auto M = a_grid_desc_m_k.GetLength(I0); + const auto K = a_grid_desc_m_k.GetLength(I1); + + const auto AK0 = K / AK1; + + return transform_tensor_descriptor(a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeAGridDescriptor_AKWmma_MBlockRepeat_MWaves_AK0PerWmma_AKRow_MPerWmma_AK1( + const AGridDesc_M_K& a_grid_desc_m_k, + const WmmaK&, + const MRepeat&, + const MWaves&, + const MPerWmma&, + const AK1&) + { + const auto M0 = a_grid_desc_m_k.GetLength(I0) / MPerBlock; + const auto K = a_grid_desc_m_k.GetLength(I1); + const auto AKWmma = K / WmmaK{}; + constexpr auto AKRow = 2; + constexpr auto AK0PerWmma = WmmaK{} / AKRow / AK1{}; + + return transform_tensor_descriptor( + a_grid_desc_m_k, 
+ make_tuple(make_unmerge_transform( + make_tuple(AKWmma, Number{}, Number{}, AK1{})), + make_unmerge_transform(make_tuple(M0 * MRepeat{}, MWaves{}, MPerWmma{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + + // + // B (alias of B0) + // + __host__ __device__ static auto MakeB0GridDescriptorPair( + const std::array& b0_gs_ns_ks_lengths_vec, + const std::array& b0_gs_ns_ks_strides_vec) + { + return MakeGridDescriptorPair(b0_gs_ns_ks_lengths_vec, + b0_gs_ns_ks_strides_vec); + } + + // TODO: rename to G_MRaw_NRaw + __host__ __device__ static auto MakeB0GridDescriptor_G_N_K( + const std::array& b0_gs_ns_ks_lengths_vec, + const std::array& b0_gs_ns_ks_strides_vec) + { + return MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).first; + } + __host__ __device__ static auto MakeB0GridDescriptor_N_K( + const std::array& b0_gs_ns_ks_lengths_vec, + const std::array& b0_gs_ns_ks_strides_vec) + { + // alias of matrix_padder.PadB0Descriptor_N_K + return matrix_padder.PadBDescriptor_N_K( + MakeB0GridDescriptorPair(b0_gs_ns_ks_lengths_vec, b0_gs_ns_ks_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeB0GridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k, const Number& BK1) + { + const auto N = b_grid_desc_n_k.GetLength(I0); + const auto K = b_grid_desc_n_k.GetLength(I1); + + const auto BK0 = K / BK1; + + return transform_tensor_descriptor(b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeB0GridDescriptor_BKWmma_LBlockRepeat_LWaves_BK0PerWmma_BKRow_LPerWmma_BK1( + const BGridDesc_L_K& b_grid_desc_l_k, + const WmmaK&, + const LRepeat&, + const LWaves&, + const LPerWmma&, + const BK1&) + { + const auto L0 = b_grid_desc_l_k.GetLength(I0) / NPerBlock; + const auto K = b_grid_desc_l_k.GetLength(I1); + const auto BKWmma = K / WmmaK{}; + constexpr auto BKRow = 2; + constexpr auto BK0PerWmma = WmmaK{} / BKRow / BK1{}; + + return transform_tensor_descriptor( + b_grid_desc_l_k, + make_tuple(make_unmerge_transform( + make_tuple(BKWmma, Number{}, Number{}, BK1{})), + make_unmerge_transform(make_tuple(L0 * LRepeat{}, LWaves{}, LPerWmma{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + + // + // B1 + // + __host__ __device__ static auto MakeB1GridDescriptorPair( + const std::array& b1_gs_os_ns_lengths_vec, + const std::array& b1_gs_os_ns_strides_vec) + { + return MakeGridDescriptorPair(b1_gs_os_ns_lengths_vec, + b1_gs_os_ns_strides_vec); + } + + // TODO: rename to G_NRaw_KRaw + __host__ __device__ static auto MakeB1GridDescriptor_G_N_K( + const std::array& b1_gs_os_ns_lengths_vec, + const std::array& b1_gs_os_ns_strides_vec) + { + return MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).first; + } + __host__ __device__ static auto MakeB1GridDescriptor_N_K( + const std::array& b1_gs_os_ns_lengths_vec, + const std::array& b1_gs_os_ns_strides_vec) + { + // alias of matrix_padder.PadB1Descriptor_O_N + return matrix_padder.PadB1Descriptor_N_K( + MakeB1GridDescriptorPair(b1_gs_os_ns_lengths_vec, b1_gs_os_ns_strides_vec).second); + } + + template + __host__ __device__ static constexpr auto + MakeB1GridDescriptor_BK0_N_BK1(const B1GridDesc_N_K& b1_grid_desc_n_k, const Number& B1K1) + { 
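+        // Unmerge K into (B1K0, B1K1) with B1K0 = K / B1K1 and reorder the result
+        // into a (B1K0, N, B1K1) descriptor; K is assumed to be divisible by B1K1.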
+ const auto N = b1_grid_desc_n_k.GetLength(I0); + const auto K = b1_grid_desc_n_k.GetLength(I1); + + const auto B1K0 = K / B1K1; + + return transform_tensor_descriptor( + b1_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + } + + template + __host__ __device__ static constexpr auto + MakeB1GridDescriptor_BLWmma_NBlockRepeat_NWaves__BL0PerWmma_BLRow_NPerWmma_BL1( + const BGridDesc_N_L& b_grid_desc_n_l, + const WmmaL&, + const NRepeat&, + const NWaves&, + const NPerWmma&, + const BL1&) + { + const auto N0 = b_grid_desc_n_l.GetLength(I0) / OPerBlock; + const auto L = b_grid_desc_n_l.GetLength(I1); + const auto BLWmma = L / WmmaL{}; + constexpr auto BLRow = 2; + constexpr auto BL0PerWmma = WmmaL{} / BLRow / BL1{}; + + return transform_tensor_descriptor( + b_grid_desc_n_l, + make_tuple(make_unmerge_transform( + make_tuple(BLWmma, Number{}, Number{}, BL1{})), + make_unmerge_transform(make_tuple(N0 * NRepeat{}, NWaves{}, NPerWmma{}))), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 3, 4, 6>{}, Sequence<1, 2, 5>{})); + } + + // + // C + // + __host__ __device__ static auto MakeCGridDescriptorPair( + const std::array& c_gs_ms_os_lengths_vec, + const std::array& c_gs_ms_os_strides_vec) + { + return MakeGridDescriptorPair(c_gs_ms_os_lengths_vec, + c_gs_ms_os_strides_vec); + } + + // TODO: rename to G_MRaw_NRaw + __host__ __device__ static auto MakeCGridDescriptor_G_M_N( + const std::array& c_gs_ms_os_lengths_vec, + const std::array& c_gs_ms_os_strides_vec) + { + return MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).first; + } + __host__ __device__ static auto MakeCGridDescriptor_M_N( + const std::array& c_gs_ms_os_lengths_vec, + const std::array& c_gs_ms_os_strides_vec) + { + return matrix_padder.PadCDescriptor_M_N( + MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second); + } +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp new file mode 100755 index 000000000000..977c622f069c --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp @@ -0,0 +1,1484 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
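+//
+// Maps a backward-data convolution onto implicit GEMM. Following the constructor
+// below, the filter is decomposed per spatial dimension with
+//   XTilde = ConvStrideW / gcd(ConvStrideW, ConvDilationW)     (likewise Y/Z),
+//   WTilde = Wo + ceil(ConvDilationW * (X - 1) / ConvStrideW),
+//   XDot   = ceil(X / XTilde),
+// so each (IdxZTilde_, IdxYTilde_, IdxXTilde_) slice is handled as its own GEMM
+// problem. With SplitN enabled, N is reduced via GetSplitedNSize() so that the
+// A and C descriptors stay below the 2 GB element-space limit.
+//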
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" + +namespace ck { +namespace tensor_operation { + +template < + index_t NDimSpatial, + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization, + index_t AK1, + index_t BK1, + index_t GemmMPerBlock, + index_t GemmNPerBlock, + index_t GemmKPerBlock, + bool DoPadGemmM, + bool DoPadGemmN, + typename ALayout, + typename BLayout, + typename CLayout, + bool SplitN = false, + typename ADataType = float, + typename CDataType = float, + index_t NumGroupsToMerge = 1, + typename IndexType = index_t, + bool CTranspose = false> +struct TransformConvBwdDataToGemm_v1 +{ + private: + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + static constexpr auto NonSpatialDimsNum = Number<3>{}; + + static constexpr auto DIdx = NonSpatialDimsNum; + static constexpr auto HIdx = + NDimSpatial == 2 ? NonSpatialDimsNum : Number{}; + static constexpr auto WIdx = + NDimSpatial == 2 ? Number{} : Number{}; + + static constexpr auto ZIdx = NonSpatialDimsNum; + static constexpr auto YIdx = + NDimSpatial == 2 ? NonSpatialDimsNum : Number{}; + static constexpr auto XIdx = + NDimSpatial == 2 ? Number{} : Number{}; + + template + static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths, + const ConvDimsType& strides, + index_t i) + { + long_index_t acc = 1; + for(; i < (NDimSpatial + 3); i++) + { + acc += + static_cast(lengths[i] - I1) * static_cast(strides[i]); + } + + return acc; + } + + template + static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_k_wos_lengths, + const ConvDimsType& a_g_n_k_wos_strides, + const ConvDimsType& c_g_n_c_wis_lengths, + const ConvDimsType& c_g_n_c_wis_strides) + { + const long_index_t a_element_space_size = + calculate_element_space_size_impl(a_g_n_k_wos_lengths, a_g_n_k_wos_strides, I1); + const long_index_t c_element_space_size = + calculate_element_space_size_impl(c_g_n_c_wis_lengths, c_g_n_c_wis_strides, I1); + const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType), + c_element_space_size * sizeof(CDataType)); + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const IndexType N = a_g_n_k_wos_lengths[I1]; + + if(element_space_size > TwoGB) + { + // Minimum divisor of N to not exceed 2GB + const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB); + + if(divisor <= static_cast(N)) + { + // Find least divisor of N larger than element_space_size / TwoGB + // Iterate up to sqrt(N). There are no divisors above this value. + for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N; + least_divisor++) + { + if(N % least_divisor == 0) + { + return N / least_divisor; + } + } + // Not found, process one Convolution N per block + return 1; + } + else + { + // Split Convolution's N dimension into N workgroups. However + // this still might not result in sufficiently small tensor, + // but at least later on we could divide the image as well. + return 1; + } + } + else + { + // Split N is not needed. 
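+            // max(A, C) element space already fits below the 2 GB limit, so the full N is used.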
+ return N; + } + } + + public: + __host__ __device__ constexpr TransformConvBwdDataToGemm_v1() {} + + template + __host__ __device__ TransformConvBwdDataToGemm_v1( + const TransformConvBwdDataToGemm_v1Base& transform_conv_bwd_data_to_gemm_base) + : N_{static_cast(transform_conv_bwd_data_to_gemm_base.N_)}, + Di_{static_cast(transform_conv_bwd_data_to_gemm_base.Di_)}, + Hi_{static_cast(transform_conv_bwd_data_to_gemm_base.Hi_)}, + Wi_{static_cast(transform_conv_bwd_data_to_gemm_base.Wi_)}, + Do_{static_cast(transform_conv_bwd_data_to_gemm_base.Do_)}, + Ho_{static_cast(transform_conv_bwd_data_to_gemm_base.Ho_)}, + Wo_{static_cast(transform_conv_bwd_data_to_gemm_base.Wo_)}, + Z_{static_cast(transform_conv_bwd_data_to_gemm_base.Z_)}, + Y_{static_cast(transform_conv_bwd_data_to_gemm_base.Y_)}, + X_{static_cast(transform_conv_bwd_data_to_gemm_base.X_)}, + K_{static_cast(transform_conv_bwd_data_to_gemm_base.K_)}, + C_{static_cast(transform_conv_bwd_data_to_gemm_base.C_)}, + DiStride_{static_cast(transform_conv_bwd_data_to_gemm_base.DiStride_)}, + HiStride_{static_cast(transform_conv_bwd_data_to_gemm_base.HiStride_)}, + WiStride_{static_cast(transform_conv_bwd_data_to_gemm_base.WiStride_)}, + DoStride_{static_cast(transform_conv_bwd_data_to_gemm_base.DoStride_)}, + HoStride_{static_cast(transform_conv_bwd_data_to_gemm_base.HoStride_)}, + WoStride_{static_cast(transform_conv_bwd_data_to_gemm_base.WoStride_)}, + CStrideTensorB_{ + static_cast(transform_conv_bwd_data_to_gemm_base.CStrideTensorB_)}, + CStrideTensorC_{ + static_cast(transform_conv_bwd_data_to_gemm_base.CStrideTensorC_)}, + KStrideTensorA_{ + static_cast(transform_conv_bwd_data_to_gemm_base.KStrideTensorA_)}, + KStrideTensorB_{ + static_cast(transform_conv_bwd_data_to_gemm_base.KStrideTensorB_)}, + NStrideTensorA_{ + static_cast(transform_conv_bwd_data_to_gemm_base.NStrideTensorA_)}, + NStrideTensorC_{ + static_cast(transform_conv_bwd_data_to_gemm_base.NStrideTensorC_)}, + ConvStrideD_{static_cast(transform_conv_bwd_data_to_gemm_base.ConvStrideD_)}, + ConvStrideH_{static_cast(transform_conv_bwd_data_to_gemm_base.ConvStrideH_)}, + ConvStrideW_{static_cast(transform_conv_bwd_data_to_gemm_base.ConvStrideW_)}, + ConvDilationD_{ + static_cast(transform_conv_bwd_data_to_gemm_base.ConvDilationD_)}, + ConvDilationH_{ + static_cast(transform_conv_bwd_data_to_gemm_base.ConvDilationH_)}, + ConvDilationW_{ + static_cast(transform_conv_bwd_data_to_gemm_base.ConvDilationW_)}, + InLeftPadD_{static_cast(transform_conv_bwd_data_to_gemm_base.InLeftPadD_)}, + InLeftPadH_{static_cast(transform_conv_bwd_data_to_gemm_base.InLeftPadH_)}, + InLeftPadW_{static_cast(transform_conv_bwd_data_to_gemm_base.InLeftPadW_)}, + InRightPadD_{static_cast(transform_conv_bwd_data_to_gemm_base.InRightPadD_)}, + InRightPadH_{static_cast(transform_conv_bwd_data_to_gemm_base.InRightPadH_)}, + InRightPadW_{static_cast(transform_conv_bwd_data_to_gemm_base.InRightPadW_)}, + IdxZTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.IdxZTilde_)}, + IdxYTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.IdxYTilde_)}, + IdxXTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.IdxXTilde_)}, + GcdStrideDilationD_{ + static_cast(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationD_)}, + GcdStrideDilationH_{ + static_cast(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationH_)}, + GcdStrideDilationW_{ + static_cast(transform_conv_bwd_data_to_gemm_base.GcdStrideDilationW_)}, + ZTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.ZTilde_)}, + 
YTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.YTilde_)}, + XTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.XTilde_)}, + DTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.DTilde_)}, + HTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.HTilde_)}, + WTilde_{static_cast(transform_conv_bwd_data_to_gemm_base.WTilde_)}, + ZDot_{static_cast(transform_conv_bwd_data_to_gemm_base.ZDot_)}, + YDot_{static_cast(transform_conv_bwd_data_to_gemm_base.YDot_)}, + XDot_{static_cast(transform_conv_bwd_data_to_gemm_base.XDot_)}, + batch_k_{transform_conv_bwd_data_to_gemm_base.batch_k_} + { + } + + template + __host__ __device__ + TransformConvBwdDataToGemm_v1(const ConvDimsType& a_g_n_k_wos_lengths, + const ConvDimsType& a_g_n_k_wos_strides, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& b_g_k_c_xs_strides, + const ConvDimsType& c_g_n_c_wis_lengths, + const ConvDimsType& c_g_n_c_wis_strides, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads, + const ConvSpatialDimsType& tildes, + const index_t batch_k = 1) + : Hi_{c_g_n_c_wis_lengths[HIdx]}, + Wi_{c_g_n_c_wis_lengths[WIdx]}, + Ho_{a_g_n_k_wos_lengths[HIdx]}, + Wo_{a_g_n_k_wos_lengths[WIdx]}, + Y_{b_g_k_c_xs_lengths[YIdx]}, + X_{b_g_k_c_xs_lengths[XIdx]}, + K_{a_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + HiStride_{c_g_n_c_wis_strides[HIdx]}, + WiStride_{c_g_n_c_wis_strides[WIdx]}, + HoStride_{a_g_n_k_wos_strides[HIdx]}, + WoStride_{a_g_n_k_wos_strides[WIdx]}, + CStrideTensorB_{b_g_k_c_xs_strides[I2]}, + CStrideTensorC_{c_g_n_c_wis_strides[I2]}, + KStrideTensorA_{a_g_n_k_wos_strides[I2]}, + KStrideTensorB_{b_g_k_c_xs_strides[I1]}, + NStrideTensorA_{a_g_n_k_wos_strides[I1]}, + NStrideTensorC_{c_g_n_c_wis_strides[I1]}, + ConvStrideH_{conv_filter_strides[HIdx - NonSpatialDimsNum]}, + ConvStrideW_{conv_filter_strides[WIdx - NonSpatialDimsNum]}, + ConvDilationH_{conv_filter_dilations[HIdx - NonSpatialDimsNum]}, + ConvDilationW_{conv_filter_dilations[WIdx - NonSpatialDimsNum]}, + InLeftPadH_{input_left_pads[HIdx - NonSpatialDimsNum]}, + InLeftPadW_{input_left_pads[WIdx - NonSpatialDimsNum]}, + InRightPadH_{input_right_pads[HIdx - NonSpatialDimsNum]}, + InRightPadW_{input_right_pads[WIdx - NonSpatialDimsNum]}, + IdxYTilde_{tildes[YIdx - NonSpatialDimsNum]}, + IdxXTilde_{tildes[XIdx - NonSpatialDimsNum]}, + batch_k_{batch_k} + { + static_assert(is_same_v> || + is_same_v>); + static_assert(is_same_v> || + is_same_v>); + + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_k_wos_lengths, a_g_n_k_wos_strides, c_g_n_c_wis_lengths, c_g_n_c_wis_strides); + } + else + { + N_ = c_g_n_c_wis_lengths[I1]; + } + if constexpr(NDimSpatial == 3) + { + Di_ = c_g_n_c_wis_lengths[DIdx]; + Do_ = a_g_n_k_wos_lengths[DIdx]; + Z_ = b_g_k_c_xs_lengths[ZIdx]; + DiStride_ = c_g_n_c_wis_strides[DIdx]; + DoStride_ = a_g_n_k_wos_strides[DIdx]; + ConvStrideD_ = conv_filter_strides[DIdx - NonSpatialDimsNum]; + ConvDilationD_ = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + InLeftPadD_ = input_left_pads[DIdx - NonSpatialDimsNum]; + InRightPadD_ = input_right_pads[DIdx - NonSpatialDimsNum]; + IdxZTilde_ = tildes[ZIdx - NonSpatialDimsNum]; + GcdStrideDilationD_ = math::gcd(ConvStrideD_, ConvDilationD_); + ZTilde_ = ConvStrideD_ / GcdStrideDilationD_; + DTilde_ = Do_ + math::integer_divide_ceil(ConvDilationD_ * (Z_ - I1), ConvStrideD_); + ZDot_ = math::integer_divide_ceil(Z_, ZTilde_); + } + 
else + { + Di_ = Do_ = Z_ = ZTilde_ = ConvStrideD_ = DTilde_ = ZDot_ = 1; + InLeftPadD_ = InRightPadD_ = DiStride_ = DoStride_ = IdxZTilde_ = 0; + } + + GcdStrideDilationH_ = math::gcd(ConvStrideH_, ConvDilationH_); + GcdStrideDilationW_ = math::gcd(ConvStrideW_, ConvDilationW_); + + YTilde_ = ConvStrideH_ / GcdStrideDilationH_; + XTilde_ = ConvStrideW_ / GcdStrideDilationW_; + + HTilde_ = Ho_ + math::integer_divide_ceil(ConvDilationH_ * (Y_ - I1), ConvStrideH_); + WTilde_ = Wo_ + math::integer_divide_ceil(ConvDilationW_ * (X_ - I1), ConvStrideW_); + + YDot_ = math::integer_divide_ceil(Y_, YTilde_); + XDot_ = math::integer_divide_ceil(X_, XTilde_); + } + +#if 0 // At now not supported to split tensor + __host__ bool AreDescriptorsSmallerThan2GB() const + { + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const long_index_t in_desc_space_size = + I1 + (N_ - I1) * NStrideTensorC_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ + + (Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorC_; + const long_index_t out_desc_space_size = + I1 + (N_ - I1) * NStrideTensorA_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ + + (Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorA_; + + bool is_a_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(ADataType)) <= TwoGB; + bool is_c_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(CDataType)) <= TwoGB; + + return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB; + } + + __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base, + CDataType* c_grid_ptr_base) const + { + // Create copies + auto conv_to_gemm_transformer_left = *this; + auto conv_to_gemm_transformer_right = *this; + IndexType a_right_offset = 0; + IndexType c_right_offset = 0; + // Calculate real filter size + const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1; + const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1; + const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1; + // Calculate start position in input for right tensor + const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_; + const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_; + const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_; + // Calculate last position in input for left tensor + const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff; + const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff; + const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff; + // Allow to split if whole left padding will be in left tensor and right padding in right + // tensor + const bool is_possible_to_split_d = Do_ != 1 && + di_right_transformer_start_idx > InLeftPadD_ && + di_left_transformer_end_idx <= (InLeftPadD_ + Di_); + const bool is_possible_to_split_h = Ho_ != 1 && + hi_right_transformer_start_idx > InLeftPadH_ && + hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_); + const bool is_possible_to_split_w = Wo_ != 1 && + wi_right_transformer_start_idx > InLeftPadW_ && + wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_); + + if(is_possible_to_split_d) + { + // Apply new sizes + // Split output on half + conv_to_gemm_transformer_left.Do_ = Do_ / 2; + conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2; + // Assign left padding to left convolution + conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_; + conv_to_gemm_transformer_right.InLeftPadD_ = 0; + // Assign right padding to right convolution + conv_to_gemm_transformer_left.InRightPadD_ 
= 0; + conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_; + // Calculate new input size + conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_; + conv_to_gemm_transformer_right.Di_ = + math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_), + (conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff); + ; + // Calcualte offsets + a_right_offset = (Do_ / 2) * DoStride_; + c_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_; + } + else if(is_possible_to_split_h) + { + conv_to_gemm_transformer_left.Ho_ = Ho_ / 2; + conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2; + + conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_; + conv_to_gemm_transformer_right.InLeftPadH_ = 0; + + conv_to_gemm_transformer_left.InRightPadH_ = 0; + conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_; + + conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_; + conv_to_gemm_transformer_right.Hi_ = + math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_), + (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff); + a_right_offset = (Ho_ / 2) * HoStride_; + c_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_; + } + else if(is_possible_to_split_w) + { + conv_to_gemm_transformer_left.Wo_ = Wo_ / 2; + conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2; + + conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_; + conv_to_gemm_transformer_right.InLeftPadW_ = 0; + + conv_to_gemm_transformer_left.InRightPadW_ = 0; + conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_; + + conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_; + conv_to_gemm_transformer_right.Wi_ = + math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_), + (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff); + + a_right_offset = (Wo_ / 2) * WoStride_; + c_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_; + } + // Return left transform, right transformer, right offset to Input and right offset to + // Output + return ck::make_tuple(conv_to_gemm_transformer_left, + conv_to_gemm_transformer_right, + a_grid_ptr_base + a_right_offset, + c_grid_ptr_base + c_right_offset); + } + + __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base, + CDataType* c_grid_ptr_base) const + { + // Create copies + auto conv_to_gemm_transformer_left = *this; + auto conv_to_gemm_transformer_right = *this; + IndexType a_right_offset = 0; + IndexType c_right_offset = 0; + + // Calculate start position in input for right tensor + const IndexType do_right_transformer_start_idx = math::integer_divide_ceil((Di_ / 2) + InLeftPadD_ - ((Z_ - 1) * ConvDilationD_), ConvStrideD_); + const IndexType ho_right_transformer_start_idx = math::integer_divide_ceil((Hi_ / 2) + InLeftPadH_ - ((Y_ - 1) * ConvDilationH_), ConvStrideH_); + const IndexType wo_right_transformer_start_idx = math::integer_divide_ceil((Wi_ / 2) + InLeftPadW_ - ((X_ - 1) * ConvDilationW_), ConvStrideW_); + // Calculate last position in input for left tensor + const IndexType do_left_transformer_end_idx = math::integer_divide_ceil((Di_ / 2 - 1) + InLeftPadD_, ConvStrideD_); + const IndexType ho_left_transformer_end_idx = math::integer_divide_ceil((Hi_ / 2 - 1) + InLeftPadH_, ConvStrideH_); + const IndexType wo_left_transformer_end_idx = math::integer_divide_ceil((Wi_ / 2 - 1) + InLeftPadW_, ConvStrideW_); + + + if(Di_!=1) + { + // Apply new sizes + // Split output on half + conv_to_gemm_transformer_left.Di_ = Di_ / 2; + 
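// Note on this (currently #if 0'd) splitting path: each SplitConvProblem overload
// makes two copies of this transformer and halves one spatial dimension. The left
// copy keeps the original left padding and the right copy keeps the original right
// padding, so together the two halves still cover the full padded input. The
// statements below finish the depth split: the right half gets the remaining
// Di_ - Di_ / 2 planes, the output extents Do_ are recomputed from the half point,
// and a_right_offset / c_right_offset give the element offsets of the right half
// inside the output and input tensors. The H and W branches below repeat the same
// pattern for the other spatial dimensions.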
conv_to_gemm_transformer_right.Di_ = Di_ - Di_ / 2; + // Assign left padding to left convolution + conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_; + conv_to_gemm_transformer_right.InLeftPadD_ = 0; + // // Assign right padding to right convolution + conv_to_gemm_transformer_left.InRightPadD_ = 0; + conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_; + // Calculate new input size + conv_to_gemm_transformer_left.Do_ = do_left_transformer_end_idx; + conv_to_gemm_transformer_right.Do_ = Do_ - do_right_transformer_start_idx; + ; + // Calcualte offsets + a_right_offset = do_right_transformer_start_idx * DoStride_; + c_right_offset = (Di_ / 2) * DiStride_; + } + else if(Hi_!=1) + { + // Apply new sizes + // Split output on half + conv_to_gemm_transformer_left.Hi_ = Hi_ / 2; + conv_to_gemm_transformer_right.Hi_ = Hi_ - Hi_ / 2; + // Assign left padding to left convolution + conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_; + conv_to_gemm_transformer_right.InLeftPadH_ = 0; + // // Assign right padding to right convolution + conv_to_gemm_transformer_left.InRightPadH_ = 0; + conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_; + // Calculate new input size + conv_to_gemm_transformer_left.Ho_ = ho_left_transformer_end_idx ; + conv_to_gemm_transformer_right.Ho_ = Ho_ - ho_right_transformer_start_idx ; + ; + // Calcualte offsets + a_right_offset = ho_right_transformer_start_idx * HoStride_; + c_right_offset = (Hi_ / 2) * HiStride_; + } + else if(Wi_!=1) + { + // Apply new sizes + // Split output on half + conv_to_gemm_transformer_left.Wi_ = Wi_ / 2; + conv_to_gemm_transformer_right.Wi_ = Wi_ - Wi_ / 2; + // Assign left padding to left convolution + conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_; + conv_to_gemm_transformer_right.InLeftPadW_ = 0; + // Assign right padding to right convolution + conv_to_gemm_transformer_left.InRightPadW_ = 0; + conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_; + // Calculate new input size + conv_to_gemm_transformer_left.Wo_ = wo_left_transformer_end_idx; + conv_to_gemm_transformer_right.Wo_ = Wo_ - wo_right_transformer_start_idx; + ; + // Calcualte offsets + a_right_offset = wo_right_transformer_start_idx * WoStride_; + c_right_offset = (Wi_ / 2) * WiStride_; + } + // Return left transform, right transformer, right offset to Input and right offset to + // Output + return ck::make_tuple(conv_to_gemm_transformer_left, + conv_to_gemm_transformer_right, + a_grid_ptr_base + a_right_offset, + c_grid_ptr_base + c_right_offset); + } +#endif + + __host__ __device__ auto MakeOutGridDesc() const + { + if constexpr(is_same_v) + { + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + + return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), + make_tuple(WoStride_, KStrideTensorA_)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N_, Ho_, Wo_, K_), + make_tuple(NStrideTensorA_, HoStride_, WoStride_, KStrideTensorA_)); + } + } + else if constexpr(is_same_v) + { + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + + return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_), + make_tuple(WoStride_, KStrideTensorA_)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N_, Do_, Ho_, Wo_, K_), + make_tuple(NStrideTensorA_, DoStride_, HoStride_, WoStride_, KStrideTensorA_)); + } + } + else if 
constexpr(is_same_v) + { + // assume packed + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor_packed(make_tuple(N_ * Ho_ * Wo_, K_)); + } + else + { + return make_naive_tensor_descriptor_packed(make_tuple(N_, Ho_, Wo_, K_)); + } + } + else if constexpr(is_same_v) + { + // assume packed + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor_packed(make_tuple(N_ * Do_ * Ho_ * Wo_, K_)); + } + else + { + return make_naive_tensor_descriptor_packed(make_tuple(N_, Do_, Ho_, Wo_, K_)); + } + } + else if constexpr(is_same_v) + { + // assume packed + static_assert(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0); + + const auto out_gemm_raw_grid_desc = make_naive_tensor_descriptor( + make_tuple(N_, Ho_ * Wo_, K_), make_tuple(NStrideTensorA_, I1, KStrideTensorA_)); + + return transform_tensor_descriptor( + out_gemm_raw_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_ * Wo_)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(is_same_v) + { + // assume packed + static_assert(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0); + + const auto out_gemm_raw_grid_desc = + make_naive_tensor_descriptor(make_tuple(N_, Do_ * Ho_ * Wo_, K_), + make_tuple(NStrideTensorA_, I1, KStrideTensorA_)); + + return transform_tensor_descriptor( + out_gemm_raw_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + throw std::runtime_error("wrong! unsupported layout: " + ALayout::name()); + } + } + + __host__ __device__ auto MakeWeiGridDesc() const + { + + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor_packed(make_tuple(K_, Y_, X_, C_)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor_packed(make_tuple(K_, Z_, Y_, X_, C_)); + } + else + { + throw std::runtime_error("wrong! unsupported layout: " + BLayout::name()); + } + } + + __host__ __device__ auto MakeInGridDesc() const + { + + if constexpr(is_same_v || + is_same_v || + is_same_v) + { + return make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStrideTensorC_, HiStride_, WiStride_, CStrideTensorC_)); + } + else if constexpr(is_same_v || + is_same_v) + { + return make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, C_), + make_tuple(NStrideTensorC_, DiStride_, HiStride_, WiStride_, CStrideTensorC_)); + } + else + { + throw std::runtime_error("wrong! 
unsupported layout: " + CLayout::name()); + } + } + + template < + typename ALayout_ = ALayout, + typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) && + (is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeADescriptor_AK0_M_AK1() const + { + // n_do_ho_wo_k for 3d or n_ho_wo_k for 2d + const auto out_grid_desc = MakeOutGridDesc(); + + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + const index_t K0PerBlock = GemmKPerBlock / AK1; + const index_t AK0 = + math::integer_divide_ceil(K_, AK1 * K0PerBlock * batch_k_) * K0PerBlock; + + // A: output tensor + const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_), + make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + const auto out_gemmak0_gemmm_gemmak1_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmak0_gemmmraw_gemmak1_grid_desc, + make_tuple(AK0 * batch_k_, GemmMPerBlock, AK1), + Sequence{}); + + return out_gemmak0_gemmm_gemmak1_grid_desc; + } + else + { + // only work on HTilde and WTilde that contribute to non-padding area of input tensor + const auto IDTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadD_ - ConvDilationD_ * (ZTilde_ - I1)), ConvStrideD_); + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_); + + const auto IDTildeSliceEnd = math::min( + DTilde_, math::integer_divide_ceil(InLeftPadD_ + Di_ - I1, ConvStrideD_) + I1); + const auto IHTildeSliceEnd = math::min( + HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1); + + const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // GemmK is different for each GEMM + const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_); + const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_); + const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_); + + if constexpr(NDimSpatial == 2) + { + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + 
make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), + make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmk_gemmm_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmk_gemmmraw_grid_desc, + make_tuple(GemmKPerBlock, GemmMPerBlock), + Sequence{}); + + const index_t K0PerBlock = GemmKPerBlock / AK1; + const index_t AK0 = + math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0), + AK1 * K0PerBlock * batch_k_) * + K0PerBlock; + + const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( + out_gemmk_gemmm_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)), + make_pass_through_transform( + out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return out_gemmak0_gemmm_gemmak1_grid_desc; + } + else if constexpr(NDimSpatial == 3) + { + // A: output tensor + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Do_, I0, I0), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), + make_pass_through_transform(K_)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc = + transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform( + make_tuple(ZDot_, DTilde_), + make_tuple(-ConvDilationD_ / GcdStrideDilationD_, I1)), + make_embed_transform( + make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), + make_embed_transform( + make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple( + 
make_pass_through_transform(N_), + make_slice_transform(ZDot_, I0, ZDotSlice), + make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{})); + + const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor( + out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple( + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), + make_merge_transform( + make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmk_gemmm_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + out_gemmk_gemmmraw_grid_desc, + make_tuple(GemmKPerBlock, GemmMPerBlock), + Sequence{}); + + const index_t K0PerBlock = GemmKPerBlock / AK1; + const index_t AK0 = + math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0), + AK1 * K0PerBlock * batch_k_) * + K0PerBlock; + + const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor( + out_gemmk_gemmm_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)), + make_pass_through_transform( + out_gemmk_gemmm_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return out_gemmak0_gemmm_gemmak1_grid_desc; + } + else + { + throw std::runtime_error("wrong! 
only implemented for 2D and 3D now"); + } + } + } + + template < + typename BLayout_ = BLayout, + typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) && + (is_same_v || + is_same_v || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeBDescriptor_BK0_N_BK1() const + { + + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + const index_t K0PerBlock = GemmKPerBlock / BK1; + const index_t BK0 = + math::integer_divide_ceil(K_, BK1 * K0PerBlock * batch_k_) * K0PerBlock; + + // B: weight tensor + const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K_, C_)), + make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, C_), make_tuple(I0, I1)); + + const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + wei_gemmbk0_gemmnraw_gemmbk1_grid_desc, + make_tuple(BK0 * batch_k_, GemmNPerBlock, BK1), + Sequence{}); + + return wei_gemmbk0_gemmn_gemmbk1_grid_desc; + } + else + { + // assume packed + // k_y_x_c for 2d or k_z_y_x_c for 3d + static_assert(is_same_v || + is_same_v); + const auto wei_grid_desc = MakeWeiGridDesc(); + + // GemmK is different for each GEMM + const auto ZDotSlice = math::integer_divide_ceil(Z_ - IdxZTilde_, ZTilde_); + const auto YDotSlice = math::integer_divide_ceil(Y_ - IdxYTilde_, YTilde_); + const auto XDotSlice = math::integer_divide_ceil(X_ - IdxXTilde_, XTilde_); + + // B weight tensor + if constexpr(NDimSpatial == 2) + { + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_grid_desc, + make_tuple( + make_pass_through_transform(K_), + make_embed_transform(make_tuple(YDot_, YTilde_), + make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor( + wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxYTilde_), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<3>{})); + + const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor( + wei_k_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto wei_gemmk_gemmn_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + wei_gemmk_gemmnraw_grid_desc, + make_tuple(GemmKPerBlock, GemmNPerBlock), + Sequence{}); + + const index_t K0PerBlock = GemmKPerBlock / BK1; + const 
index_t BK0 = + math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0), + BK1 * K0PerBlock * batch_k_) * + K0PerBlock; + + const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)), + make_pass_through_transform( + wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return wei_gemmbk0_gemmn_gemmbk1_grid_desc; + } + else if constexpr(NDimSpatial == 3) + { + const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc = + transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_embed_transform( + make_tuple(ZDot_, ZTilde_), + make_tuple(ConvStrideD_ / GcdStrideDilationD_, I1)), + make_embed_transform( + make_tuple(YDot_, YTilde_), + make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), + make_embed_transform( + make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor( + wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_slice_transform(ZDot_, I0, ZDotSlice), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxZTilde_), + make_freeze_transform(IdxYTilde_), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<2>{}, + Sequence<4>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<>{}, + Sequence<4>{})); + + const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor( + wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto wei_gemmk_gemmn_padded_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + wei_gemmk_gemmnraw_grid_desc, + make_tuple(GemmKPerBlock, GemmNPerBlock), + Sequence{}); + + const index_t K0PerBlock = GemmKPerBlock / BK1; + const index_t BK0 = + math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0), + BK1 * K0PerBlock * batch_k_) * + K0PerBlock; + + const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor( + wei_gemmk_gemmn_padded_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)), + make_pass_through_transform( + wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return wei_gemmbk0_gemm_gemmbk1_grid_desc; + } + else + { + throw std::runtime_error("wrong! 
only implemented for 2D and 3D now"); + } + } + } + + template < + typename CLayout_ = CLayout, + typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) && + (is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + static_assert(CTranspose == false); + // assume strided + // n_hi_wi_c for 2d n_di_hi_wi_c for 3d + const auto in_grid_desc = MakeInGridDesc(); + + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + // C: input tensor + if constexpr(NDimSpatial == 2) + { + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)), + make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } + else if constexpr(NDimSpatial == 3) + { + + // C: input tensor + const auto in_n_x_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(I1, Do_), make_tuple(I1, ConvStrideD_)), + make_embed_transform(make_tuple(I1, Ho_), make_tuple(I1, ConvStrideH_)), + make_embed_transform(make_tuple(I1, Wo_), make_tuple(I1, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_x_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<1>{}, + Sequence<3>{}, + Sequence<5>{}, + Sequence<0, 2, 4, 6>{}, + Sequence<7>{}), + make_tuple( + Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } + else + { + throw std::runtime_error("wrong! 
only implemented for 2D and 3D now"); + } + } + else + { + // only work on DTilde, HTilde and WTilde that contribute to + // non-padding area of input tensor + const auto IDTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadD_ - ConvDilationD_ * (ZTilde_ - I1)), ConvStrideD_); + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_); + + const auto IDTildeSliceEnd = math::min( + DTilde_, math::integer_divide_ceil(InLeftPadD_ + Di_ - I1, ConvStrideD_) + I1); + const auto IHTildeSliceEnd = math::min( + HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1); + + const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin; + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // C: input tensor + if constexpr(NDimSpatial == 2) + { + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(YTilde_, HTilde_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxYTilde_), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<3>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } + else if(NDimSpatial == 3) + { + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + 
make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc = + transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(ZTilde_, DTilde_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(YTilde_, HTilde_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc = + transform_tensor_descriptor( + in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxZTilde_), + make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(IdxYTilde_), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}, + Sequence<6>{}, + Sequence<7>{}), + make_tuple(Sequence<0>{}, + Sequence<>{}, + Sequence<1>{}, + Sequence<>{}, + Sequence<2>{}, + Sequence<>{}, + Sequence<3>{}, + Sequence<4>{})); + + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmm_gemmn_grid_desc = + ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + return in_gemmm_gemmn_grid_desc; + } + else + { + throw std::runtime_error("wrong! 
only implemented for 2D and 3D now"); + } + } + } + + template || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + const auto in_grid_desc = make_naive_tensor_descriptor( + make_tuple(N_, C_, Di_ * Hi_ * Wi_), make_tuple(NStrideTensorC_, CStrideTensorC_, I1)); + + static_assert(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0); + + if constexpr(CTranspose) + { + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(C_), + make_merge_transform(make_tuple(N_, Di_ * Hi_ * Wi_))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmNPerBlock, GemmMPerBlock), + Sequence{}); + } + else + { + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, Di_ * Hi_ * Wi_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + } + } + // for input bias + template || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + if constexpr(ConvBwdDataSpecialization == + ck::tensor_operation::device::ConvolutionBackwardDataSpecialization:: + Filter1x1Stride1Pad0) + { + if constexpr(CTranspose) + { + const auto in_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(C_, N_ * Ho_ * Wo_), make_tuple(I1, I0)); + + return in_gemmm_gemmn_grid_desc; + } + else + { + const auto in_gemmm_gemmn_grid_desc = make_naive_tensor_descriptor( + make_tuple(N_ * Ho_ * Wo_, C_), make_tuple(I0, I1)); + + return in_gemmm_gemmn_grid_desc; + } + } + else + { + static_assert(CTranspose == false); + // only work on HTilde and WTilde that contribute to non-padding area of input + // tensor + const auto IHTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadH_ - ConvDilationH_ * (YTilde_ - I1)), ConvStrideH_); + const auto IWTildeSliceBegin = math::integer_divide_floor( + math::max(I0, InLeftPadW_ - ConvDilationW_ * (XTilde_ - I1)), ConvStrideW_); + + const auto IHTildeSliceEnd = math::min( + HTilde_, math::integer_divide_ceil(InLeftPadH_ + Hi_ - I1, ConvStrideH_) + I1); + const auto IWTildeSliceEnd = math::min( + WTilde_, math::integer_divide_ceil(InLeftPadW_ + Wi_ - I1, ConvStrideW_) + I1); + + const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin; + const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin; + + // bias tensor + const auto in_gemmmraw_gemmnraw_grid_desc = make_naive_tensor_descriptor( + make_tuple(N_ * HTildeSlice * WTildeSlice, C_), make_tuple(I0, I1)); + + const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor( + in_gemmmraw_gemmnraw_grid_desc, + make_tuple(GemmMPerBlock, GemmNPerBlock), + Sequence{}); + + return in_gemmm_gemmn_grid_desc; + } + } + + IndexType N_; + IndexType Di_, Hi_, Wi_; + IndexType Do_, Ho_, Wo_; + IndexType Z_, Y_, X_; + IndexType K_, C_; + IndexType DiStride_, HiStride_, WiStride_; + IndexType DoStride_, HoStride_, WoStride_; + IndexType CStrideTensorB_, CStrideTensorC_, KStrideTensorA_, KStrideTensorB_; + IndexType 
NStrideTensorA_, NStrideTensorC_; + IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_; + IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_; + IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_; + IndexType InRightPadD_, InRightPadH_, InRightPadW_; + IndexType IdxZTilde_, IdxYTilde_, IdxXTilde_; + IndexType GcdStrideDilationD_, GcdStrideDilationH_, GcdStrideDilationW_; + IndexType ZTilde_, YTilde_, XTilde_; + IndexType DTilde_, HTilde_, WTilde_; + IndexType ZDot_, YDot_, XDot_; + index_t batch_k_; +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp new file mode 100755 index 000000000000..bd3ab1080247 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp @@ -0,0 +1,714 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/library/utility/numeric.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" + +namespace ck { +namespace tensor_operation { + +template +struct TransformConvBwdWeightToGemm +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t WoStride = output_strides[4]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K), + make_tuple(WoStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t NStride = input_strides[1]; + const index_t HiStride = input_strides[3]; + const index_t WiStride = input_strides[4]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C), + make_tuple(WiStride, CStride)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C), + make_tuple(NStride, HiStride, WiStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + return make_naive_tensor_descriptor(make_tuple(K, Y * X * C), make_tuple(KStride, CStride)); + } + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Do, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t WoStride = output_strides[5]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K), + make_tuple(WoStride, KStride)); + } + + 
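// The descriptor helpers above and below (make_out_grid_desc / make_in_grid_desc /
// make_wei_grid_desc) only build naive views of the output, input and weight
// tensors; the MakeABCGridDescriptor_* functions further down turn them into the
// backward-weight GEMM, where A is the output, B is the input and C is the weight:
//   GemmM      = K
//   GemmN      = C * <filter spatial size>   (X, X*Y or X*Y*Z)
//   GemmKTotal = N * <output spatial size>   (Wo, Ho*Wo or Do*Ho*Wo)
//   GemmK0     = ceil(GemmKTotal / (GemmK1Number * K0PerBlock * GemmKBatch)) * K0PerBlock
//   GemmKPad   = GemmKBatch * GemmK0 * GemmK1Number
// with GemmM and GemmN right-padded up to multiples of MPerBlock and NPerBlock.
// Worked 2D example using hypothetical sizes and tile parameters (N = 2, Ho = Wo = 16,
// K = 64, C = 32, Y = X = 3, GemmK1Number = 8, K0PerBlock = 4, batch_k = 1,
// MPerBlock = NPerBlock = 128): GemmKTotal = 512, GemmM = 64, GemmN = 288,
// GemmK0 = ceil(512 / 32) * 4 = 64, GemmKPad = 512 (no K padding needed),
// PadGemmM = 64 and PadGemmN = 96, i.e. the padded GEMM problem is 128 x 384.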
template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Di, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t NStride = input_strides[1]; + const index_t DiStride = input_strides[3]; + const index_t HiStride = input_strides[4]; + const index_t WiStride = input_strides[5]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C), + make_tuple(WiStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, C), + make_tuple(NStride, DiStride, HiStride, WiStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Z, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C), + make_tuple(KStride, CStride)); + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& /* input_strides */, + const std::array& /* weights_strides */, + const std::array& /* output_strides */, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t batch_k) + { + using namespace ck; + + const index_t Wi = input_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[0]; + const index_t ConvStrideW = conv_filter_strides[0]; + const index_t ConvDilationW = conv_filter_dilations[0]; + const index_t InLeftPadW = input_left_pads[0]; + const index_t InRightPadW = input_right_pads[0]; + + const index_t GemmKTotal = N * Wo; + const index_t GemmM = K; + const index_t GemmN = C * X; + + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 
0 : NPerBlock - GemmN % NPerBlock; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmktotal_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Wi, C)); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); + } + else + { + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Wo, K)); + const auto in_n_wi_c_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C)); + + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_n_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor( + 
in_n_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(X, C)), + make_merge_transform(make_tuple(N, Wo))), + make_tuple(Sequence<1, 3>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, X * C)); + + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t batch_k) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const 
index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmKTotal = N * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * X * Y; + + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: 
input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const 
std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t batch_k) + { + using namespace ck; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t GemmKTotal = N * Do * Ho * Wo; + const index_t GemmM = K; + const index_t GemmN = C * Z * X * Y; + + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, 
Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto 
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch), + make_pass_through_transform(GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } // function end +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp new file mode 100755 index 000000000000..b72ddb824379 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp @@ -0,0 +1,909 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/library/utility/numeric.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" + +namespace ck { +namespace tensor_operation { + +/** + * @brief Transform conv bwd weight to gemm v2 + * + * This version does following things: + * 1. Merge KBatch with K0 to align descriptor with universal gemm + * 2. Merge Batch with M and N dimension. It allows to increase compute in + * case of small M and N. It also allows to vector load and store in case of + * K = 1, C = 1 and NHWGC layout. 
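+ * The resulting GEMM view (derived from the descriptor builders below) is GemmM = K * NumGroupsToMerge, GemmN = C * (filter spatial size) * NumGroupsToMerge and GemmKTotal = N * (output spatial size), with GemmKTotal padded up to GemmKBatch * GemmK0 * GemmK1.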
+ */ +template +struct TransformConvBwdWeightToGemmV2 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t BatchStride = output_strides[0]; + const index_t WoStride = output_strides[3]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Wo, NumGroupsToMerge, K), + make_tuple(WoStride, BatchStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t BatchStride = input_strides[0]; + const index_t NStride = input_strides[1]; + const index_t WiStride = input_strides[3]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Wi, NumGroupsToMerge, C), + make_tuple(WiStride, BatchStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Wi, NumGroupsToMerge, C), + make_tuple(NStride, WiStride, BatchStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + const auto XStride = weights_strides[3]; + const auto BatchStride = weights_strides[0]; + // Add NumGroupsToMerge for Batch+M dimension and, 1 as a placehorder + // for Batch+N dimension + const auto desc = make_naive_tensor_descriptor( + make_tuple(NumGroupsToMerge, K, X, 1, C), + make_tuple(BatchStride, KStride, XStride, BatchStride, CStride)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + desc, + make_tuple(make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K), + make_pass_through_transform(X), + make_pad_transform(1, 0, NumGroupsToMerge - 1), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + // We need only matrices from diagonal. Xor returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. 
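// A minimal standalone check of the diagonal-selection property described in the
// comment above (illustrative only, not part of this header): the placeholder
// Batch+N dimension is padded so that only index 0 holds real data, and the xor
// transform sends a (g_m, g_n) group pair to padded index g_m ^ g_n, which is 0
// exactly on the diagonal. NumGroupsToMerge = 4 is an assumed example value.

#include <cassert>
#include <cstdio>

int main()
{
    constexpr int NumGroupsToMerge = 4; // must be a power of two, as asserted below
    for(int g_m = 0; g_m < NumGroupsToMerge; ++g_m)
        for(int g_n = 0; g_n < NumGroupsToMerge; ++g_n)
        {
            const int padded_idx = g_m ^ g_n; // index into the padded placeholder dim
            assert((padded_idx == 0) == (g_m == g_n));
        }
    std::puts("only diagonal group pairs land on real (unpadded) data");
    return 0;
}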
+ static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 || + NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K), + make_pass_through_transform(X), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)), + make_merge_transform(make_tuple(X, NumGroupsToMerge, C))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t BatchStride = output_strides[0]; + const index_t WoStride = output_strides[4]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, NumGroupsToMerge, K), + make_tuple(WoStride, BatchStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t BatchStride = input_strides[0]; + const index_t NStride = input_strides[1]; + const index_t HiStride = input_strides[3]; + const index_t WiStride = input_strides[4]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, NumGroupsToMerge, C), + make_tuple(WiStride, BatchStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Hi, Wi, NumGroupsToMerge, C), + make_tuple(NStride, HiStride, WiStride, BatchStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + const auto XStride = weights_strides[4]; + const auto BatchStride = weights_strides[0]; + // Add NumGroupsToMerge for Batch+M dimension and, 1 as a placehorder + // for Batch+N dimension + const auto desc = make_naive_tensor_descriptor( + make_tuple(NumGroupsToMerge, K, Y * X, 1, C), + make_tuple(BatchStride, KStride, XStride, BatchStride, CStride)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + desc, + make_tuple(make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K), + make_pass_through_transform(Y * X), + make_pad_transform(1, 0, NumGroupsToMerge - 1), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + // We need only matrices from diagonal. Xor returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. 
+ // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 || + NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K), + make_pass_through_transform(Y * X), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)), + make_merge_transform(make_tuple(Y * X, NumGroupsToMerge, C))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template ::type = false> + constexpr static auto + make_out_grid_desc(const index_t N, + const index_t Do, + const index_t Ho, + const index_t Wo, + const index_t K, + const std::array& output_strides) + { + const index_t BatchStride = output_strides[0]; + const index_t WoStride = output_strides[5]; + const auto KStride = Number<1>{}; + return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, NumGroupsToMerge, K), + make_tuple(WoStride, BatchStride, KStride)); + } + + template ::type = false> + constexpr static auto + make_in_grid_desc(const index_t N, + const index_t Di, + const index_t Hi, + const index_t Wi, + const index_t C, + const std::array& input_strides) + { + const index_t BatchStride = input_strides[0]; + const index_t NStride = input_strides[1]; + const index_t DiStride = input_strides[3]; + const index_t HiStride = input_strides[4]; + const index_t WiStride = input_strides[5]; + const auto CStride = input_strides[2]; + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, NumGroupsToMerge, C), + make_tuple(WiStride, BatchStride, CStride)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C), + make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride)); + } + } + + template ::type = false> + constexpr static auto + make_wei_grid_desc(const index_t K, + const index_t Z, + const index_t Y, + const index_t X, + const index_t C, + const std::array& weights_strides) + { + const auto CStride = Number<1>{}; + const auto KStride = weights_strides[1]; + const auto XStride = weights_strides[5]; + const auto BatchStride = weights_strides[0]; + // Add NumGroupsToMerge for Batch+M dimension and, 1 for placehord for Batch+N dimension + const auto desc = make_naive_tensor_descriptor( + make_tuple(NumGroupsToMerge, K, Z * Y * X, 1, C), + make_tuple(BatchStride, KStride, XStride, BatchStride, CStride)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + desc, + make_tuple(make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K), + make_pass_through_transform(Z * Y * X), + make_pad_transform(1, 0, NumGroupsToMerge - 1), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, 
Sequence<4>{})); + // We need only matrices from diagonal. Xor returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 || + NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K), + make_pass_through_transform(Z * Y * X), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)), + make_merge_transform(make_tuple(Z * Y * X, NumGroupsToMerge, C))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t batch_k) + { + using namespace ck; + + const index_t Wi = input_spatial_lengths[0]; + + const index_t Wo = output_spatial_lengths[0]; + + const index_t X = filter_spatial_lengths[0]; + + const index_t ConvStrideW = conv_filter_strides[0]; + + const index_t ConvDilationW = conv_filter_dilations[0]; + + const index_t InLeftPadW = input_left_pads[0]; + + const index_t InRightPadW = input_right_pads[0]; + + const index_t GemmKTotal = N * Wo; + const index_t GemmM = K * NumGroupsToMerge; + const index_t GemmN = C * X * NumGroupsToMerge; + + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 
0 : NPerBlock - GemmN % NPerBlock; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + 
make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(X, NumGroupsToMerge, C)), + make_merge_transform(make_tuple(N, Wo))), + make_tuple(Sequence<1, 3, 4>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + + } // function end + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t batch_k) + { + using namespace ck; + + const index_t Hi = input_spatial_lengths[0]; + const index_t Wi = input_spatial_lengths[1]; + + const index_t Ho = output_spatial_lengths[0]; + const index_t Wo = output_spatial_lengths[1]; + + const index_t Y = filter_spatial_lengths[0]; + const index_t X = filter_spatial_lengths[1]; + + const index_t ConvStrideH = conv_filter_strides[0]; + const index_t ConvStrideW = conv_filter_strides[1]; + + const index_t ConvDilationH = 
conv_filter_dilations[0]; + const index_t ConvDilationW = conv_filter_dilations[1]; + + const index_t InLeftPadH = input_left_pads[0]; + const index_t InLeftPadW = input_left_pads[1]; + + const index_t InRightPadH = input_right_pads[0]; + const index_t InRightPadW = input_right_pads[1]; + + const index_t GemmKTotal = N * Ho * Wo; + const index_t GemmM = K * NumGroupsToMerge; + const index_t GemmN = C * X * Y * NumGroupsToMerge; + + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + 
make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5>{}, + Sequence<6>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, NumGroupsToMerge, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } + + template ::type = false> + static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N( + const index_t N, + const index_t K, + const index_t C, + const std::array& input_spatial_lengths, + const std::array& filter_spatial_lengths, + const std::array& 
output_spatial_lengths, + const std::array& input_strides, + const std::array& weights_strides, + const std::array& output_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const index_t batch_k) + { + using namespace ck; + + const index_t Di = input_spatial_lengths[0]; + const index_t Hi = input_spatial_lengths[1]; + const index_t Wi = input_spatial_lengths[2]; + + const index_t Do = output_spatial_lengths[0]; + const index_t Ho = output_spatial_lengths[1]; + const index_t Wo = output_spatial_lengths[2]; + + const index_t Z = filter_spatial_lengths[0]; + const index_t Y = filter_spatial_lengths[1]; + const index_t X = filter_spatial_lengths[2]; + + const index_t ConvStrideD = conv_filter_strides[0]; + const index_t ConvStrideH = conv_filter_strides[1]; + const index_t ConvStrideW = conv_filter_strides[2]; + + const index_t ConvDilationD = conv_filter_dilations[0]; + const index_t ConvDilationH = conv_filter_dilations[1]; + const index_t ConvDilationW = conv_filter_dilations[2]; + + const index_t InLeftPadD = input_left_pads[0]; + const index_t InLeftPadH = input_left_pads[1]; + const index_t InLeftPadW = input_left_pads[2]; + + const index_t InRightPadD = input_right_pads[0]; + const index_t InRightPadH = input_right_pads[1]; + const index_t InRightPadW = input_right_pads[2]; + + const index_t GemmKTotal = N * Do * Ho * Wo; + const index_t GemmM = K * NumGroupsToMerge; + const index_t GemmN = C * Z * X * Y * NumGroupsToMerge; + + const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; + const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; + + const index_t GemmKBatch = batch_k; + const index_t GemmK0 = + math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) * + K0PerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number; + + const auto out_grid_desc = make_out_grid_desc(N, Do, Ho, Wo, K, output_strides); + const auto in_grid_desc = make_in_grid_desc(N, Di, Hi, Wi, C, input_strides); + const auto wei_grid_desc = make_wei_grid_desc(K, Z, Y, X, C, weights_strides); + + if constexpr(ConvBackwardWeightSpecialization == + device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0) + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + 
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_grid_desc); + } + else + { + // A: output tensor + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple( + make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: input tensor + const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Di, InLeftPadD, InRightPadD), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_dip_hip_wip_c_grid_desc, + make_tuple( + make_pass_through_transform(N), + make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{}, + Sequence<8>{})); + + const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumGroupsToMerge, C)), + make_merge_transform(make_tuple(N, Do, Ho, Wo))), + make_tuple(Sequence<1, 3, 5, 7, 8>{}, Sequence<0, 2, 4, 6>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // Padd + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + 
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmM, PadGemmM), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc = + transform_tensor_descriptor( + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + make_tuple(make_pass_through_transform(GemmKBatch * GemmK0), + make_right_pad_transform(GemmN, PadGemmN), + make_pass_through_transform(GemmK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto wei_gemmm_gemmn_pad_grid_desc = + transform_tensor_descriptor(wei_grid_desc, + make_tuple(make_right_pad_transform(GemmM, PadGemmM), + make_right_pad_transform(GemmN, PadGemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc, + wei_gemmm_gemmn_pad_grid_desc); + } + } // function end +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp new file mode 100755 index 000000000000..92b48c44b332 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp @@ -0,0 +1,1787 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
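A minimal host-side sketch of the K-split and tile-padding arithmetic that every descriptor builder above applies before handing the problem to the GEMM (illustrative only; the tile sizes are assumed example values, not the kernel's actual configuration):

#include <cstdio>

int main()
{
    // Example 2D conv bwd-weight problem: N=4, K=96, C=20, Ho=Wo=28, Y=X=3
    const int N = 4, K = 96, C = 20, Ho = 28, Wo = 28, Y = 3, X = 3;
    const int NumGroupsToMerge = 1;
    // Hypothetical tile configuration
    const int MPerBlock = 128, NPerBlock = 128, K0PerBlock = 4, GemmK1Number = 8;
    const int GemmKBatch = 2; // split-K factor

    const int GemmKTotal = N * Ho * Wo;                  // 3136
    const int GemmM      = K * NumGroupsToMerge;         // 96
    const int GemmN      = C * Y * X * NumGroupsToMerge; // 180

    // Right-pad M and N up to whole block tiles
    const int PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock; // 32
    const int PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock; // 76

    // GemmK0 = integer_divide_ceil(GemmKTotal, GemmK1 * K0PerBlock * GemmKBatch) * K0PerBlock
    const int grain    = GemmK1Number * K0PerBlock * GemmKBatch;        // 64
    const int GemmK0   = (GemmKTotal + grain - 1) / grain * K0PerBlock; // 196
    const int GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;            // 3136

    std::printf("GemmKTotal=%d -> GemmKPad=%d (KBatch=%d, K0=%d, K1=%d)\n",
                GemmKTotal, GemmKPad, GemmKBatch, GemmK0, GemmK1Number);
    std::printf("GemmM=%d (+%d pad), GemmN=%d (+%d pad)\n", GemmM, PadGemmM, GemmN, PadGemmN);
    return 0;
}

Padding GemmKTotal up to GemmKBatch * GemmK0 * GemmK1 gives every split-K slice the same number of K0 iterations, which is what lets the descriptors merge KBatch with K0 as noted in the struct's comment.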
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" + +namespace ck { +namespace tensor_operation { + +template +struct TransformConvFwdToGemm +{ + private: + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + template + static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths, + const ConvDimsType& strides, + index_t i) + { + long_index_t acc = 1; + for(; i < (NDimSpatial + 3); i++) + { + acc += + static_cast(lengths[i] - I1) * static_cast(strides[i]); + } + + return acc; + } + + template + static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides) + { + const long_index_t a_element_space_size = + calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1); + const long_index_t c_element_space_size = + calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1); + const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType), + c_element_space_size * sizeof(CDataType)); + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const IndexType N = a_g_n_c_wis_lengths[I1]; + + if(element_space_size > TwoGB) + { + // Minimum divisor of N to not exceed 2GB + const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB); + + if(divisor <= static_cast(N)) + { + // Find least divisor of N larger than element_space_size / TwoGB + // Iterate up to sqrt(N). There are no divisors above this value. + for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N; + least_divisor++) + { + if(N % least_divisor == 0) + { + return N / least_divisor; + } + } + // Not found, process one Convolution N per block + return 1; + } + else + { + // Split Convolution's N dimension into N workgroups. However + // this still might not result in sufficiently small tensor, + // but at least later on we could divide the image as well. + return 1; + } + } + else + { + // Split N is not needed. 
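// A minimal standalone sketch of the 2 GB splitting heuristic implemented by
// GetSplitedNSize above (illustrative only; split_n and the example sizes are
// hypothetical): pick the smallest divisor of N that brings the descriptor's
// element space under 2 GB, else fall back to one image of the batch per invocation.

#include <cstdint>
#include <cstdio>

int64_t split_n(int64_t N, int64_t element_space_bytes)
{
    constexpr int64_t TwoGB = int64_t{1} << 31;
    if(element_space_bytes <= TwoGB)
        return N; // no split needed
    const int64_t divisor = (element_space_bytes + TwoGB - 1) / TwoGB;
    if(divisor > N)
        return 1; // even N = 1 may be too big; the image can still be split later
    // mirror the header's search: smallest divisor of N >= divisor, candidates up to sqrt(N)
    for(int64_t d = divisor; d * d <= N; ++d)
        if(N % d == 0)
            return N / d;
    return 1;
}

int main()
{
    // e.g. a 3 GB tensor with N = 12 is processed as two halves of N = 6
    std::printf("%lld\n", static_cast<long long>(split_n(12, int64_t{3} << 30)));
    return 0;
}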
+ return N; + } + } + + public: + __host__ __device__ constexpr TransformConvFwdToGemm() {} + + template + __host__ __device__ + TransformConvFwdToGemm(const TransformConvFwdToGemmBase& transform_conv_fwd_to_gemm_base) + : N_{static_cast(transform_conv_fwd_to_gemm_base.N_)}, + Di_{static_cast(transform_conv_fwd_to_gemm_base.Di_)}, + Hi_{static_cast(transform_conv_fwd_to_gemm_base.Hi_)}, + Wi_{static_cast(transform_conv_fwd_to_gemm_base.Wi_)}, + Do_{static_cast(transform_conv_fwd_to_gemm_base.Do_)}, + Ho_{static_cast(transform_conv_fwd_to_gemm_base.Ho_)}, + Wo_{static_cast(transform_conv_fwd_to_gemm_base.Wo_)}, + Z_{static_cast(transform_conv_fwd_to_gemm_base.Z_)}, + Y_{static_cast(transform_conv_fwd_to_gemm_base.Y_)}, + X_{static_cast(transform_conv_fwd_to_gemm_base.X_)}, + K_{static_cast(transform_conv_fwd_to_gemm_base.K_)}, + C_{static_cast(transform_conv_fwd_to_gemm_base.C_)}, + DiStride_{static_cast(transform_conv_fwd_to_gemm_base.DiStride_)}, + HiStride_{static_cast(transform_conv_fwd_to_gemm_base.HiStride_)}, + WiStride_{static_cast(transform_conv_fwd_to_gemm_base.WiStride_)}, + DoStride_{static_cast(transform_conv_fwd_to_gemm_base.DoStride_)}, + HoStride_{static_cast(transform_conv_fwd_to_gemm_base.HoStride_)}, + WoStride_{static_cast(transform_conv_fwd_to_gemm_base.WoStride_)}, + XStride_{static_cast(transform_conv_fwd_to_gemm_base.XStride_)}, + CStrideTensorA_{static_cast(transform_conv_fwd_to_gemm_base.CStrideTensorA_)}, + CStrideTensorB_{static_cast(transform_conv_fwd_to_gemm_base.CStrideTensorB_)}, + KStrideTensorB_{static_cast(transform_conv_fwd_to_gemm_base.KStrideTensorB_)}, + KStrideTensorC_{static_cast(transform_conv_fwd_to_gemm_base.KStrideTensorC_)}, + NStrideTensorA_{static_cast(transform_conv_fwd_to_gemm_base.NStrideTensorA_)}, + NStrideTensorC_{static_cast(transform_conv_fwd_to_gemm_base.NStrideTensorC_)}, + GStrideTensorA_{static_cast(transform_conv_fwd_to_gemm_base.GStrideTensorA_)}, + GStrideTensorB_{static_cast(transform_conv_fwd_to_gemm_base.GStrideTensorB_)}, + GStrideTensorC_{static_cast(transform_conv_fwd_to_gemm_base.GStrideTensorC_)}, + ConvStrideD_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideD_)}, + ConvStrideH_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideH_)}, + ConvStrideW_{static_cast(transform_conv_fwd_to_gemm_base.ConvStrideW_)}, + ConvDilationD_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationD_)}, + ConvDilationH_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationH_)}, + ConvDilationW_{static_cast(transform_conv_fwd_to_gemm_base.ConvDilationW_)}, + InLeftPadD_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadD_)}, + InLeftPadH_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadH_)}, + InLeftPadW_{static_cast(transform_conv_fwd_to_gemm_base.InLeftPadW_)}, + InRightPadD_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadD_)}, + InRightPadH_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadH_)}, + InRightPadW_{static_cast(transform_conv_fwd_to_gemm_base.InRightPadW_)}, + ZYX_{static_cast(transform_conv_fwd_to_gemm_base.ZYX_)} + { + } + + template ::type = false> + __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& b_g_k_c_xs_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + 
const ConvSpatialDimsType& input_right_pads) + : Di_{I1}, + Hi_{I1}, + Wi_{a_g_n_c_wis_lengths[I3]}, + Do_{I1}, + Ho_{I1}, + Wo_{c_g_n_k_wos_lengths[I3]}, + Z_{I1}, + Y_{I1}, + X_{b_g_k_c_xs_lengths[I3]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + DiStride_{I1}, + HiStride_{I1}, + WiStride_{a_g_n_c_wis_strides[I3]}, + DoStride_{I1}, + HoStride_{I1}, + WoStride_{c_g_n_k_wos_strides[I3]}, + XStride_{b_g_k_c_xs_strides[I3]}, + CStrideTensorA_{a_g_n_c_wis_strides[I2]}, + CStrideTensorB_{b_g_k_c_xs_strides[I2]}, + KStrideTensorB_{b_g_k_c_xs_strides[I1]}, + KStrideTensorC_{c_g_n_k_wos_strides[I2]}, + NStrideTensorA_{a_g_n_c_wis_strides[I1]}, + NStrideTensorC_{c_g_n_k_wos_strides[I1]}, + GStrideTensorA_{a_g_n_c_wis_strides[I0]}, + GStrideTensorB_{b_g_k_c_xs_strides[I0]}, + GStrideTensorC_{c_g_n_k_wos_strides[I0]}, + ConvStrideD_{I1}, + ConvStrideH_{I1}, + ConvStrideW_{conv_filter_strides[I0]}, + ConvDilationD_{I1}, + ConvDilationH_{I1}, + ConvDilationW_{conv_filter_dilations[I0]}, + InLeftPadD_{I0}, + InLeftPadH_{I0}, + InLeftPadW_{input_left_pads[I0]}, + InRightPadD_{I0}, + InRightPadH_{I0}, + InRightPadW_{input_right_pads[I0]}, + ZYX_{X_} + { +#ifdef CK_CODE_GEN_RTC + static_assert(is_same_v>); + static_assert(is_same_v>); +#else + static_assert(is_same_v> || + is_same_v>); + static_assert(is_same_v> || + is_same_v>); +#endif + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } + } + + template ::type = false> + __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& b_g_k_c_xs_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : Di_{I1}, + Hi_{a_g_n_c_wis_lengths[I3]}, + Wi_{a_g_n_c_wis_lengths[I4]}, + Do_{I1}, + Ho_{c_g_n_k_wos_lengths[I3]}, + Wo_{c_g_n_k_wos_lengths[I4]}, + Z_{I1}, + Y_{b_g_k_c_xs_lengths[I3]}, + X_{b_g_k_c_xs_lengths[I4]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + DiStride_{I1}, + HiStride_{a_g_n_c_wis_strides[I3]}, + WiStride_{a_g_n_c_wis_strides[I4]}, + DoStride_{I1}, + HoStride_{c_g_n_k_wos_strides[I3]}, + WoStride_{c_g_n_k_wos_strides[I4]}, + XStride_{b_g_k_c_xs_strides[I4]}, + CStrideTensorA_{a_g_n_c_wis_strides[I2]}, + CStrideTensorB_{b_g_k_c_xs_strides[I2]}, + KStrideTensorB_{b_g_k_c_xs_strides[I1]}, + KStrideTensorC_{c_g_n_k_wos_strides[I2]}, + NStrideTensorA_{a_g_n_c_wis_strides[I1]}, + NStrideTensorC_{c_g_n_k_wos_strides[I1]}, + GStrideTensorA_{a_g_n_c_wis_strides[I0]}, + GStrideTensorB_{b_g_k_c_xs_strides[I0]}, + GStrideTensorC_{c_g_n_k_wos_strides[I0]}, + ConvStrideD_{I1}, + ConvStrideH_{conv_filter_strides[I0]}, + ConvStrideW_{conv_filter_strides[I1]}, + ConvDilationD_{I1}, + ConvDilationH_{conv_filter_dilations[I0]}, + ConvDilationW_{conv_filter_dilations[I1]}, + InLeftPadD_{I0}, + InLeftPadH_{input_left_pads[I0]}, + InLeftPadW_{input_left_pads[I1]}, + InRightPadD_{I0}, + InRightPadH_{input_right_pads[I0]}, + InRightPadW_{input_right_pads[I1]}, + ZYX_{Y_ * X_} + { +#ifdef CK_CODE_GEN_RTC + static_assert(is_same_v>); + static_assert(is_same_v>); +#else + static_assert(is_same_v> || + is_same_v>); + 
static_assert(is_same_v> || + is_same_v>); +#endif + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } + } + + template ::type = false> + __host__ __device__ TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths, + const ConvDimsType& a_g_n_c_wis_strides, + const ConvDimsType& b_g_k_c_xs_lengths, + const ConvDimsType& b_g_k_c_xs_strides, + const ConvDimsType& c_g_n_k_wos_lengths, + const ConvDimsType& c_g_n_k_wos_strides, + const ConvSpatialDimsType& conv_filter_strides, + const ConvSpatialDimsType& conv_filter_dilations, + const ConvSpatialDimsType& input_left_pads, + const ConvSpatialDimsType& input_right_pads) + : Di_{a_g_n_c_wis_lengths[I3]}, + Hi_{a_g_n_c_wis_lengths[I4]}, + Wi_{a_g_n_c_wis_lengths[I5]}, + Do_{c_g_n_k_wos_lengths[I3]}, + Ho_{c_g_n_k_wos_lengths[I4]}, + Wo_{c_g_n_k_wos_lengths[I5]}, + Z_{b_g_k_c_xs_lengths[I3]}, + Y_{b_g_k_c_xs_lengths[I4]}, + X_{b_g_k_c_xs_lengths[I5]}, + K_{c_g_n_k_wos_lengths[I2]}, + C_{b_g_k_c_xs_lengths[I2]}, + DiStride_{a_g_n_c_wis_strides[I3]}, + HiStride_{a_g_n_c_wis_strides[I4]}, + WiStride_{a_g_n_c_wis_strides[I5]}, + DoStride_{c_g_n_k_wos_strides[I3]}, + HoStride_{c_g_n_k_wos_strides[I4]}, + WoStride_{c_g_n_k_wos_strides[I5]}, + XStride_{b_g_k_c_xs_strides[I5]}, + CStrideTensorA_{a_g_n_c_wis_strides[I2]}, + CStrideTensorB_{b_g_k_c_xs_strides[I2]}, + KStrideTensorB_{b_g_k_c_xs_strides[I1]}, + KStrideTensorC_{c_g_n_k_wos_strides[I2]}, + NStrideTensorA_{a_g_n_c_wis_strides[I1]}, + NStrideTensorC_{c_g_n_k_wos_strides[I1]}, + GStrideTensorA_{a_g_n_c_wis_strides[I0]}, + GStrideTensorB_{b_g_k_c_xs_strides[I0]}, + GStrideTensorC_{c_g_n_k_wos_strides[I0]}, + ConvStrideD_{conv_filter_strides[I0]}, + ConvStrideH_{conv_filter_strides[I1]}, + ConvStrideW_{conv_filter_strides[I2]}, + ConvDilationD_{conv_filter_dilations[I0]}, + ConvDilationH_{conv_filter_dilations[I1]}, + ConvDilationW_{conv_filter_dilations[I2]}, + InLeftPadD_{input_left_pads[I0]}, + InLeftPadH_{input_left_pads[I1]}, + InLeftPadW_{input_left_pads[I2]}, + InRightPadD_{input_right_pads[I0]}, + InRightPadH_{input_right_pads[I1]}, + InRightPadW_{input_right_pads[I2]}, + ZYX_{Z_ * Y_ * X_} + { +#ifdef CK_CODE_GEN_RTC + static_assert(is_same_v>); + static_assert(is_same_v>); +#else + static_assert(is_same_v> || + is_same_v>); + static_assert(is_same_v> || + is_same_v>); +#endif + if constexpr(SplitN) + { + N_ = GetSplitedNSize( + a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + } + else + { + N_ = c_g_n_k_wos_lengths[I1]; + } + } + + __host__ bool AreDescriptorsSmallerThan2GB() const + { + constexpr long_index_t TwoGB = (long_index_t{1} << 31); + + const long_index_t in_desc_space_size = + I1 + (N_ - I1) * NStrideTensorA_ + (Di_ - I1) * DiStride_ + (Hi_ - I1) * HiStride_ + + (Wi_ - I1) * WiStride_ + (C_ - I1) * CStrideTensorA_; + const long_index_t out_desc_space_size = + I1 + (N_ - I1) * NStrideTensorC_ + (Do_ - I1) * DoStride_ + (Ho_ - I1) * HoStride_ + + (Wo_ - I1) * WoStride_ + (K_ - I1) * KStrideTensorC_; + + bool is_a_descriptor_smaller_than_2GB = (in_desc_space_size * sizeof(ADataType)) <= TwoGB; + bool is_c_descriptor_smaller_than_2GB = (out_desc_space_size * sizeof(CDataType)) <= TwoGB; + + return is_a_descriptor_smaller_than_2GB && is_c_descriptor_smaller_than_2GB; + } + + __host__ auto SplitConvProblem(const ADataType* a_grid_ptr_base, + CDataType* c_grid_ptr_base) const + { + // Create copies + auto 
conv_to_gemm_transformer_left = *this; + auto conv_to_gemm_transformer_right = *this; + IndexType a_right_offset = 0; + IndexType c_right_offset = 0; + // Calculate real filter size + const IndexType z_eff = (Z_ - 1) * ConvDilationD_ + 1; + const IndexType y_eff = (Y_ - 1) * ConvDilationH_ + 1; + const IndexType x_eff = (X_ - 1) * ConvDilationW_ + 1; + // Calculate start position in input for right tensor + const IndexType di_right_transformer_start_idx = (Do_ / 2) * ConvStrideD_; + const IndexType hi_right_transformer_start_idx = (Ho_ / 2) * ConvStrideH_; + const IndexType wi_right_transformer_start_idx = (Wo_ / 2) * ConvStrideW_; + // Calculate last position in input for left tensor + const IndexType di_left_transformer_end_idx = (Do_ / 2 - 1) * ConvStrideD_ + z_eff; + const IndexType hi_left_transformer_end_idx = (Ho_ / 2 - 1) * ConvStrideH_ + y_eff; + const IndexType wi_left_transformer_end_idx = (Wo_ / 2 - 1) * ConvStrideW_ + x_eff; + // Allow to split if whole left padding will be in left tensor and right padding in right + // tensor + const bool is_possible_to_split_d = Do_ != 1 && + di_right_transformer_start_idx > InLeftPadD_ && + di_left_transformer_end_idx <= (InLeftPadD_ + Di_); + const bool is_possible_to_split_h = Ho_ != 1 && + hi_right_transformer_start_idx > InLeftPadH_ && + hi_left_transformer_end_idx <= (InLeftPadH_ + Hi_); + const bool is_possible_to_split_w = Wo_ != 1 && + wi_right_transformer_start_idx > InLeftPadW_ && + wi_left_transformer_end_idx <= (InLeftPadW_ + Wi_); + + if(is_possible_to_split_d) + { + // Apply new sizes + // Split output on half + conv_to_gemm_transformer_left.Do_ = Do_ / 2; + conv_to_gemm_transformer_right.Do_ = Do_ - Do_ / 2; + // Assign left padding to left convolution + conv_to_gemm_transformer_left.InLeftPadD_ = InLeftPadD_; + conv_to_gemm_transformer_right.InLeftPadD_ = 0; + // Assign right padding to right convolution + conv_to_gemm_transformer_left.InRightPadD_ = 0; + conv_to_gemm_transformer_right.InRightPadD_ = InRightPadD_; + // Calculate new input size + conv_to_gemm_transformer_left.Di_ = di_left_transformer_end_idx - InLeftPadD_; + conv_to_gemm_transformer_right.Di_ = + math::min(Di_ - (di_right_transformer_start_idx - InLeftPadD_), + (conv_to_gemm_transformer_right.Do_ - 1) * ConvStrideD_ + z_eff); + ; + // Calcualte offsets + a_right_offset = ((Do_ / 2) * ConvStrideD_ - InLeftPadD_) * DiStride_; + c_right_offset = (Do_ / 2) * DoStride_; + } + else if(is_possible_to_split_h) + { + conv_to_gemm_transformer_left.Ho_ = Ho_ / 2; + conv_to_gemm_transformer_right.Ho_ = Ho_ - Ho_ / 2; + + conv_to_gemm_transformer_left.InLeftPadH_ = InLeftPadH_; + conv_to_gemm_transformer_right.InLeftPadH_ = 0; + + conv_to_gemm_transformer_left.InRightPadH_ = 0; + conv_to_gemm_transformer_right.InRightPadH_ = InRightPadH_; + + conv_to_gemm_transformer_left.Hi_ = hi_left_transformer_end_idx - InLeftPadH_; + conv_to_gemm_transformer_right.Hi_ = + math::min(Hi_ - (hi_right_transformer_start_idx - InLeftPadH_), + (conv_to_gemm_transformer_right.Ho_ - 1) * ConvStrideH_ + y_eff); + a_right_offset = ((Ho_ / 2) * ConvStrideH_ - InLeftPadH_) * HiStride_; + c_right_offset = (Ho_ / 2) * HoStride_; + } + else if(is_possible_to_split_w) + { + conv_to_gemm_transformer_left.Wo_ = Wo_ / 2; + conv_to_gemm_transformer_right.Wo_ = Wo_ - Wo_ / 2; + + conv_to_gemm_transformer_left.InLeftPadW_ = InLeftPadW_; + conv_to_gemm_transformer_right.InLeftPadW_ = 0; + + conv_to_gemm_transformer_left.InRightPadW_ = 0; + conv_to_gemm_transformer_right.InRightPadW_ = InRightPadW_; + + 
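+            // Calculate the new input size and the right-tensor offsets for the
+            // W split, following the same scheme as the D and H branches above.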
conv_to_gemm_transformer_left.Wi_ = wi_left_transformer_end_idx - InLeftPadW_; + conv_to_gemm_transformer_right.Wi_ = + math::min(Wi_ - (wi_right_transformer_start_idx - InLeftPadW_), + (conv_to_gemm_transformer_right.Wo_ - 1) * ConvStrideW_ + x_eff); + + a_right_offset = ((Wo_ / 2) * ConvStrideW_ - InLeftPadW_) * WiStride_; + c_right_offset = (Wo_ / 2) * WoStride_; + } + // Return left transform, right transformer, right offset to Input and right offset to + // Output + return ck::make_tuple(conv_to_gemm_transformer_left, + conv_to_gemm_transformer_right, + a_grid_ptr_base + a_right_offset, + c_grid_ptr_base + c_right_offset); + } + + // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimemsion as + // properties + template || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeADescriptor_M_K() const + { + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wo_, C_), + make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_)); + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wo_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter3x3) + { + if constexpr(NumGroupsToMerge == 1) + { + + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_), make_tuple(NStrideTensorA_, WiStride_)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(Number<3>{})), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, NumGroupsToMerge), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, 
Sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(Number<3>{})), + make_tuple(Sequence<0, 2, 3>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, C_), + make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_)); + + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + const auto in_n_wo_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + in_n_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, C_), + make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_merge_transform(make_tuple(X_, C_))), + 
make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + const auto in_n_wip_c_desc = transform_tensor_descriptor( + in_n_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_x_wo_c_desc = transform_tensor_descriptor( + in_n_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4>{})); + + return transform_tensor_descriptor( + in_n_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(X_, C_))), + make_tuple(Sequence<0, 2, 3>{}, Sequence<1, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + } + + template || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeADescriptor_M_K() const + + { + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Ho_, Wo_, C_), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, C_), + make_tuple( + NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter3x3) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_), make_tuple(NStrideTensorA_, HiStride_, WiStride_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Number<3>{}, Ho_), + make_tuple(ConvDilationH_, 
ConvStrideH_)), + make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_)); + + const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hip_wip_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_)); + + const auto in_n_ho_wo_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + in_n_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple( + NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + const auto in_n_ho_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + 
make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + return transform_tensor_descriptor( + in_n_ho_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_merge_transform(make_tuple(Y_, X_, C_))), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + + const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple( + NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_)); + + const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor( + in_n_hi_wi_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor( + in_n_hip_wip_groups_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5>{}, + Sequence<6>{})); + + return transform_tensor_descriptor( + in_n_y_ho_x_wo_groups_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, 
NumGroupsToMerge)), + make_merge_transform(make_tuple(Y_, X_, C_))), + make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + } + + template || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeADescriptor_M_K() const + + { + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Do_, Ho_, Wo_, C_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, + DiStride_, + HiStride_, + WiStride_, + GStrideTensorA_, + CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_groups_gemmk_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter3x3) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Number<3>{}, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5, 6>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + 
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Number<3>{}, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Number<3>{}, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(Number<3>{}, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(Number<3>{}, Number<3>{}, Number<3>{}))), + make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0) + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_)); + + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + return transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, + DiStride_, + HiStride_, + WiStride_, + GStrideTensorA_, + CStrideTensorA_)); + + const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)), + make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)), + make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + return transform_tensor_descriptor( + in_n_do_ho_wo_c_desc, + make_tuple( + make_merge_transform(make_tuple(N_, Do_, 
Ho_, Wo_, NumGroupsToMerge)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1, 2, 3, 4>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else + { + if constexpr(NumGroupsToMerge == 1) + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, C_), + make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Z_, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_merge_transform(make_tuple(Z_, Y_, X_, C_))), + make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor( + make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_), + make_tuple(NStrideTensorA_, + DiStride_, + HiStride_, + WiStride_, + GStrideTensorA_, + CStrideTensorA_)); + + const auto in_n_hip_wip_c_desc = transform_tensor_descriptor( + in_n_di_hi_wi_c_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{})); + + const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor( + in_n_hip_wip_c_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(Z_, Do_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(Y_, Ho_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(X_, Wo_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(C_)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1, 2>{}, + Sequence<3, 4>{}, + Sequence<5, 6>{}, + Sequence<7>{}, + Sequence<8>{})); + + return transform_tensor_descriptor( + in_n_z_do_y_ho_x_wo_c_desc, + make_tuple( + 
make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)), + make_merge_transform(make_tuple(Z_, Y_, X_, C_))), + make_tuple(Sequence<0, 2, 4, 6, 7>{}, Sequence<1, 3, 5, 8>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + } + + template , + bool>::type = false> + __host__ __device__ auto MakeADescriptor_M_K() const + { + static_assert(NumGroupsToMerge == 1); + static_assert(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0); + + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wo_, C_), make_tuple(NStrideTensorA_, I1, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template , + bool>::type = false> + __host__ __device__ auto MakeADescriptor_M_K() const + { + static_assert(NumGroupsToMerge == 1); + static_assert(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0); + + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Ho_ * Wo_, C_), make_tuple(NStrideTensorA_, I1, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_ * Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template , + bool>::type = false> + __host__ __device__ auto MakeADescriptor_M_K() const + { + static_assert(NumGroupsToMerge == 1); + static_assert(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0); + + const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(N_, Do_ * Ho_ * Wo_, C_), make_tuple(NStrideTensorA_, I1, CStrideTensorA_)); + + return transform_tensor_descriptor( + in_gemmm_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_)), + make_pass_through_transform(C_)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + + template || + is_same_v || + is_same_v, + bool>::type = false> + __host__ __device__ auto MakeBDescriptor_N_K() const + { + static_assert(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 || + ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter1x1Pad0); + static_assert(NumGroupsToMerge == 1); + return make_naive_tensor_descriptor_packed(make_tuple(K_, C_)); + } + + template || + is_same_v || + is_same_v, + bool>::type = false> + __host__ __device__ auto MakeBDescriptor_N_K() const + { + if constexpr(ConvForwardSpecialization == + device::ConvolutionForwardSpecialization::Filter3x3) + { + using FilterSizeNumType = + ck::conditional_t, + ck::conditional_t, Number<27>>>; + + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor_packed(make_tuple(K_, FilterSizeNumType{})); + } + else + { + + const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}), + make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_)); + return transform_tensor_descriptor( + wei_gemmn_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)), + make_pass_through_transform(FilterSizeNumType{})), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + 
make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + else + { + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor_packed(make_tuple(K_, ZYX_ * C_)); + } + else + { + const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor( + make_tuple(K_, NumGroupsToMerge, ZYX_ * C_), + make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_)); + return transform_tensor_descriptor( + wei_gemmn_groups_gemmk_desc, + make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)), + make_pass_through_transform(ZYX_ * C_)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + } + + template < + typename BLayout, + typename ck::enable_if || + is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v, + bool>::type = false> + __host__ __device__ auto MakeBDescriptor_N_K() const + { + const auto wei_k_yx_c_desc = make_naive_tensor_descriptor( + make_tuple(K_, ZYX_, C_), make_tuple(KStrideTensorB_, XStride_, CStrideTensorB_)); + + const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor( + wei_k_yx_c_desc, + make_tuple(make_pass_through_transform(K_), make_merge_transform(make_tuple(ZYX_, C_))), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return wei_gemmn_gemmk_desc; + } + + template < + typename CLayout, + index_t NDimSp = NDimSpatial, + + typename ck::enable_if), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + if constexpr(CTranspose) + { + return make_naive_tensor_descriptor(make_tuple(K_, N_ * Wo_), + make_tuple(KStrideTensorC_, I0)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_ * Wo_, K_), + make_tuple(I0, KStrideTensorC_)); + } + } + + template < + typename CLayout, + index_t NDimSp = NDimSpatial, + + typename ck::enable_if), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + if constexpr(CTranspose) + { + return make_naive_tensor_descriptor(make_tuple(K_, N_ * Ho_ * Wo_), + make_tuple(KStrideTensorC_, I0)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), + make_tuple(I0, KStrideTensorC_)); + } + } + + template < + typename CLayout, + index_t NDimSp = NDimSpatial, + + typename ck::enable_if), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + if constexpr(CTranspose) + { + return make_naive_tensor_descriptor(make_tuple(K_, N_ * Do_ * Ho_ * Wo_), + make_tuple(KStrideTensorC_, I0)); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_), + make_tuple(I0, KStrideTensorC_)); + } + } + + template || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + static_assert(CTranspose == false); + const IndexType NDoHoWo = N_ * Wo_; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_), + make_tuple(WoStride_, KStrideTensorC_)); + } + else + { + const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor( + make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1), + make_tuple( + NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + nhwo_groups_k_1_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + 
make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + // We need only matrices from diagonal. X_or returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. + static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || + NumGroupsToMerge == 32 || NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_pass_through_transform(NDoHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), + make_merge_transform(make_tuple(K_, NumGroupsToMerge))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + template || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + static_assert(CTranspose == false); + const IndexType NDoHoWo = N_ * Ho_ * Wo_; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_), + make_tuple(WoStride_, KStrideTensorC_)); + } + else + { + const auto nhwo_groups_k_1_desc = + make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, K_, 1), + make_tuple(NStrideTensorC_, + HoStride_, + WoStride_, + GStrideTensorC_, + KStrideTensorC_, + GStrideTensorC_)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + nhwo_groups_k_1_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + // We need only matrices from diagonal. X_or returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. 
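+            // Illustration (arbitrary numbers): with NumGroupsToMerge = 4, the block at
+            // group-row 2 / group-col 2 maps through 2 xor 2 = 0 into the real data,
+            // while group-row 2 / group-col 3 maps through 2 xor 3 = 1 into the padded
+            // region and is therefore discarded.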
+ static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || + NumGroupsToMerge == 32 || NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_pass_through_transform(NDoHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), + make_merge_transform(make_tuple(K_, NumGroupsToMerge))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + template || + is_same_v || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + static_assert(CTranspose == false); + const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_; + if constexpr(NumGroupsToMerge == 1) + { + return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_), + make_tuple(WoStride_, KStrideTensorC_)); + } + else + { + const auto nhwo_groups_k_1_desc = + make_naive_tensor_descriptor(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, K_, 1), + make_tuple(NStrideTensorC_, + DoStride_, + HoStride_, + WoStride_, + GStrideTensorC_, + KStrideTensorC_, + GStrideTensorC_)); + // Padd 1 to NumGroupsToMerge + const auto padded_desc = transform_tensor_descriptor( + nhwo_groups_k_1_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)), + make_pass_through_transform(NumGroupsToMerge), + make_pass_through_transform(K_), + make_pad_transform(1, 0, NumGroupsToMerge - 1)), + make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}, Sequence<5>{}, Sequence<6>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + // We need only matrices from diagonal. X_or returns 0 for the same + // values. So if matrices is not on diagonal then it will be stored in padding. + // To avoid use of modulo after xor we assume that NumBatch to merge is power of 2. 
+ static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 || + NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || + NumGroupsToMerge == 32 || NumGroupsToMerge == 64); + const auto unmerged_padded_desc = transform_tensor_descriptor( + padded_desc, + make_tuple(make_pass_through_transform(NDoHoWo), + make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2>{})); + // Merge To M, N + return transform_tensor_descriptor( + unmerged_padded_desc, + make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)), + make_merge_transform(make_tuple(K_, NumGroupsToMerge))), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + template || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + static_assert(NumGroupsToMerge == 1); + auto n_k_wo_desc = make_naive_tensor_descriptor( + make_tuple(N_, K_, Wo_), make_tuple(NStrideTensorC_, KStrideTensorC_, I1)); + if constexpr(CTranspose) + { + return transform_tensor_descriptor( + n_k_wo_desc, + make_tuple(make_pass_through_transform(K_), + make_merge_transform(make_tuple(N_, Wo_))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor(n_k_wo_desc, + make_tuple(make_merge_transform(make_tuple(N_, Wo_)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + template || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + static_assert(NumGroupsToMerge == 1); + auto n_k_howo_desc = make_naive_tensor_descriptor( + make_tuple(N_, K_, Ho_ * Wo_), make_tuple(NStrideTensorC_, KStrideTensorC_, I1)); + if constexpr(CTranspose) + { + return transform_tensor_descriptor( + n_k_howo_desc, + make_tuple(make_pass_through_transform(K_), + make_merge_transform(make_tuple(N_, Ho_ * Wo_))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + n_k_howo_desc, + make_tuple(make_merge_transform(make_tuple(N_, Ho_ * Wo_)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + + template || + is_same_v), + bool>::type = false> + __host__ __device__ auto MakeCDescriptor_M_N() const + { + static_assert(NumGroupsToMerge == 1); + auto n_k_dohowo_desc = make_naive_tensor_descriptor( + make_tuple(N_, K_, Do_ * Ho_ * Wo_), make_tuple(NStrideTensorC_, KStrideTensorC_, I1)); + + if constexpr(CTranspose) + { + return transform_tensor_descriptor( + n_k_dohowo_desc, + make_tuple(make_pass_through_transform(K_), + make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + return transform_tensor_descriptor( + n_k_dohowo_desc, + make_tuple(make_merge_transform(make_tuple(N_, Do_ * Ho_ * Wo_)), + make_pass_through_transform(K_)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + } + IndexType N_; + IndexType Di_, Hi_, Wi_; + IndexType Do_, Ho_, Wo_; + IndexType Z_, Y_, X_; + IndexType K_, C_; + IndexType DiStride_, HiStride_, WiStride_; + IndexType 
DoStride_, HoStride_, WoStride_; + IndexType XStride_; + IndexType CStrideTensorA_, CStrideTensorB_, KStrideTensorB_, KStrideTensorC_; + IndexType NStrideTensorA_, NStrideTensorC_; + IndexType GStrideTensorA_, GStrideTensorB_, GStrideTensorC_; + IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_; + IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_; + IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_; + IndexType InRightPadD_, InRightPadH_, InRightPadW_; + IndexType ZYX_; +}; + +// wrapper class to call member functions on TransformConvToGemm struct at runtime +// TODO: figure out aq way to properly pass in layout as an argument +struct TransformConv +{ + TransformConv() {} + + template + auto + transform_func(TransformConvFwdToGemm conv_fwd_to_gemm) + { + if(NDimSpatial == 2) + { + return conv_fwd_to_gemm + .template MakeCDescriptor_M_N(); + } + else if(NDimSpatial == 3) + { + return conv_fwd_to_gemm + .template MakeCDescriptor_M_N(); + } + else if(NDimSpatial == 1) + { + return conv_fwd_to_gemm + .template MakeCDescriptor_M_N(); + } + } +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp b/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp new file mode 100755 index 000000000000..0f28fe816930 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp @@ -0,0 +1,466 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" + +namespace ck { +namespace tensor_operation { + +/* + * Transform Convolution NGCHW to NHWGC. We transform [N, G, C, H, W] tensor + * descriptor to [N * G * C, H * W] (input or output image). The first + * dimension is store dimension, the second one is load dimension. For + * NHWGC to NGCHW load and store are reverted. For weight we transform + * [G, K, C, Y, X] to [G * K * Y * X, C]. First dim is load dimension, + * second dim is store dimension. 
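+ * For illustration (arbitrary sizes): an image tensor with N = 2, G = 4, C = 8,
+ * H = 16, W = 16 is viewed as a [2 * 4 * 8, 16 * 16] = [64, 256] descriptor,
+ * and a weight tensor with G = 4, K = 32, C = 8, Y = X = 3 is viewed as a
+ * [4 * 32 * 3 * 3, 8] = [1152, 8] descriptor.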
+ */ + +template +struct TransformConvNGCHWToNHWGC +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + + template ::type = false> + static auto + MakeNGCHWTransposeDesc(const std::array& g_n_c_wis_lengths, + const std::array& g_n_c_wis_strides, + const index_t split_n_size = 1) + { + const index_t& G = g_n_c_wis_lengths[I0]; + const index_t N = g_n_c_wis_lengths[I1] / split_n_size; + const index_t& C = g_n_c_wis_lengths[I2]; + const index_t& Wi = g_n_c_wis_lengths[I3]; + + const index_t& GStride = g_n_c_wis_strides[I0]; + const index_t& NStride = g_n_c_wis_strides[I1]; + const index_t& CStride = g_n_c_wis_strides[I2]; + const index_t& WiStride = g_n_c_wis_strides[I3]; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(N, G, C, Wi), make_tuple(NStride, GStride, CStride, WiStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(N, G, C)), + make_merge_transform(make_tuple(Wi))), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto + MakeNHWGCTransposeDesc(const std::array& g_n_c_wis_lengths, + const std::array& g_n_c_wis_strides, + const index_t split_n_size = 1) + { + const index_t& G = g_n_c_wis_lengths[I0]; + const index_t N = g_n_c_wis_lengths[I1] / split_n_size; + const index_t& C = g_n_c_wis_lengths[I2]; + const index_t& Wi = g_n_c_wis_lengths[I3]; + + const index_t& NStride = g_n_c_wis_strides[I1]; + const index_t WiStride = G * C; + const index_t GStride = C; + const index_t CStride = 1; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(N, G, C, Wi), make_tuple(NStride, GStride, CStride, WiStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(N, G, C)), + make_merge_transform(make_tuple(Wi))), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto + MakeNGCHWTransposeDesc(const std::array& g_n_c_wis_lengths, + const std::array& g_n_c_wis_strides, + const index_t split_n_size = 1) + { + const index_t& G = g_n_c_wis_lengths[I0]; + const index_t N = g_n_c_wis_lengths[I1] / split_n_size; + const index_t& C = g_n_c_wis_lengths[I2]; + const index_t& Hi = g_n_c_wis_lengths[I3]; + const index_t& Wi = g_n_c_wis_lengths[I4]; + + const index_t& GStride = g_n_c_wis_strides[I0]; + const index_t& NStride = g_n_c_wis_strides[I1]; + const index_t& CStride = g_n_c_wis_strides[I2]; + const index_t& HiStride = g_n_c_wis_strides[I3]; + const index_t& WiStride = g_n_c_wis_strides[I4]; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(N, G, C)), + make_merge_transform(make_tuple(Hi, Wi))), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), 
Sequence{}); + } + + template ::type = false> + static auto + MakeNHWGCTransposeDesc(const std::array& g_n_c_wis_lengths, + const std::array& g_n_c_wis_strides, + const index_t split_n_size = 1) + { + const index_t& G = g_n_c_wis_lengths[I0]; + const index_t N = g_n_c_wis_lengths[I1] / split_n_size; + const index_t& C = g_n_c_wis_lengths[I2]; + const index_t& Hi = g_n_c_wis_lengths[I3]; + const index_t& Wi = g_n_c_wis_lengths[I4]; + + const index_t& NStride = g_n_c_wis_strides[I1]; + const index_t HiStride = Wi * G * C; + const index_t WiStride = G * C; + const index_t GStride = C; + const index_t CStride = 1; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(N, G, C, Hi, Wi), make_tuple(NStride, GStride, CStride, HiStride, WiStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(N, G, C)), + make_merge_transform(make_tuple(Hi, Wi))), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto + MakeNGCHWTransposeDesc(const std::array& g_n_c_wis_lengths, + const std::array& g_n_c_wis_strides, + const index_t split_n_size = 1) + { + const index_t& G = g_n_c_wis_lengths[I0]; + const index_t N = g_n_c_wis_lengths[I1] / split_n_size; + const index_t& C = g_n_c_wis_lengths[I2]; + const index_t& Di = g_n_c_wis_lengths[I3]; + const index_t& Hi = g_n_c_wis_lengths[I4]; + const index_t& Wi = g_n_c_wis_lengths[I5]; + + const index_t& GStride = g_n_c_wis_strides[I0]; + const index_t& NStride = g_n_c_wis_strides[I1]; + const index_t& CStride = g_n_c_wis_strides[I2]; + const index_t& DiStride = g_n_c_wis_strides[I3]; + const index_t& HiStride = g_n_c_wis_strides[I4]; + const index_t& WiStride = g_n_c_wis_strides[I5]; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(N, G, C, Di, Hi, Wi), + make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(N, G, C)), + make_merge_transform(make_tuple(Di, Hi, Wi))), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto + MakeNHWGCTransposeDesc(const std::array& g_n_c_wis_lengths, + const std::array& g_n_c_wis_strides, + const index_t split_n_size = 1) + { + const index_t& G = g_n_c_wis_lengths[I0]; + const index_t N = g_n_c_wis_lengths[I1] / split_n_size; + const index_t& C = g_n_c_wis_lengths[I2]; + const index_t& Di = g_n_c_wis_lengths[I3]; + const index_t& Hi = g_n_c_wis_lengths[I4]; + const index_t& Wi = g_n_c_wis_lengths[I5]; + + const index_t& NStride = g_n_c_wis_strides[I1]; + const index_t DiStride = Hi * Wi * G * C; + const index_t HiStride = Wi * G * C; + const index_t WiStride = G * C; + const index_t GStride = C; + const index_t CStride = 1; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(N, G, C, Di, Hi, Wi), + make_tuple(NStride, GStride, CStride, DiStride, HiStride, WiStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(N, G, C)), + make_merge_transform(make_tuple(Di, Hi, Wi))), + make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + 
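+        // Pad both merged dimensions up to multiples of MPerThread and NPerThread,
+        // presumably so the transpose kernel always works on full thread tiles.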
return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto + MakeGKCYXTransposeDesc(const std::array& g_k_c_wis_lengths, + const std::array& g_k_c_wis_strides) + { + const index_t& G = g_k_c_wis_lengths[I0]; + const index_t& K = g_k_c_wis_lengths[I1]; + const index_t& C = g_k_c_wis_lengths[I2]; + const index_t& X = g_k_c_wis_lengths[I3]; + + const index_t& GStride = g_k_c_wis_strides[I0]; + const index_t& KStride = g_k_c_wis_strides[I1]; + const index_t& CStride = g_k_c_wis_strides[I2]; + const index_t& XStride = g_k_c_wis_strides[I3]; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(G, K, C, X), make_tuple(GStride, KStride, CStride, XStride)); + const auto merged_desc = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(make_tuple(G, K, X)), make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto + MakeGKYXCTransposeDesc(const std::array& g_k_c_wis_lengths, + const std::array& g_k_c_wis_strides) + { + const index_t& G = g_k_c_wis_lengths[I0]; + const index_t& K = g_k_c_wis_lengths[I1]; + const index_t& C = g_k_c_wis_lengths[I2]; + const index_t& X = g_k_c_wis_lengths[I3]; + + const index_t& GStride = g_k_c_wis_strides[I0]; + const index_t KStride = g_k_c_wis_strides[I1]; + const index_t CStride = 1; + const index_t XStride = C; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(G, K, C, X), make_tuple(GStride, KStride, CStride, XStride)); + const auto merged_desc = transform_tensor_descriptor( + desc, + make_tuple(make_merge_transform(make_tuple(G, K, X)), make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto + MakeGKCYXTransposeDesc(const std::array& g_k_c_wis_lengths, + const std::array& g_k_c_wis_strides) + { + const index_t& G = g_k_c_wis_lengths[I0]; + const index_t& K = g_k_c_wis_lengths[I1]; + const index_t& C = g_k_c_wis_lengths[I2]; + const index_t& Y = g_k_c_wis_lengths[I3]; + const index_t& X = g_k_c_wis_lengths[I4]; + + const index_t& GStride = g_k_c_wis_strides[I0]; + const index_t& KStride = g_k_c_wis_strides[I1]; + const index_t& CStride = g_k_c_wis_strides[I2]; + const index_t& YStride = g_k_c_wis_strides[I3]; + const index_t& XStride = g_k_c_wis_strides[I4]; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(G, K, C, Y, X), make_tuple(GStride, KStride, CStride, YStride, XStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(G, K, Y, X)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 3, 4>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto + MakeGKYXCTransposeDesc(const std::array& g_k_c_wis_lengths, + const std::array& g_k_c_wis_strides) + { + const index_t& G = g_k_c_wis_lengths[I0]; + const index_t& K = g_k_c_wis_lengths[I1]; + const index_t& C = g_k_c_wis_lengths[I2]; + const index_t& Y = g_k_c_wis_lengths[I3]; + const index_t& X = 
g_k_c_wis_lengths[I4]; + + const index_t& GStride = g_k_c_wis_strides[I0]; + const index_t KStride = g_k_c_wis_strides[I1]; + const index_t CStride = 1; + const index_t YStride = X * C; + const index_t XStride = C; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(G, K, C, Y, X), make_tuple(GStride, KStride, CStride, YStride, XStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(G, K, Y, X)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 3, 4>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto + MakeGKCYXTransposeDesc(const std::array& g_k_c_wis_lengths, + const std::array& g_k_c_wis_strides) + { + const index_t& G = g_k_c_wis_lengths[I0]; + const index_t& K = g_k_c_wis_lengths[I1]; + const index_t& C = g_k_c_wis_lengths[I2]; + const index_t& Z = g_k_c_wis_lengths[I3]; + const index_t& Y = g_k_c_wis_lengths[I4]; + const index_t& X = g_k_c_wis_lengths[I5]; + + const index_t& GStride = g_k_c_wis_strides[I0]; + const index_t& KStride = g_k_c_wis_strides[I1]; + const index_t& CStride = g_k_c_wis_strides[I2]; + const index_t& ZStride = g_k_c_wis_strides[I3]; + const index_t& YStride = g_k_c_wis_strides[I4]; + const index_t& XStride = g_k_c_wis_strides[I5]; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(G, K, C, Z, Y, X), + make_tuple(GStride, KStride, CStride, ZStride, YStride, XStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(G, K, Z, Y, X)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 3, 4, 5>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + template ::type = false> + static auto + MakeGKYXCTransposeDesc(const std::array& g_k_c_wis_lengths, + const std::array& g_k_c_wis_strides) + { + const index_t& G = g_k_c_wis_lengths[I0]; + const index_t& K = g_k_c_wis_lengths[I1]; + const index_t& C = g_k_c_wis_lengths[I2]; + const index_t& Z = g_k_c_wis_lengths[I3]; + const index_t& Y = g_k_c_wis_lengths[I4]; + const index_t& X = g_k_c_wis_lengths[I5]; + + const index_t& GStride = g_k_c_wis_strides[I0]; + const index_t KStride = g_k_c_wis_strides[I1]; + const index_t CStride = 1; + const index_t ZStride = Y * X * C; + const index_t YStride = X * C; + const index_t XStride = C; + + const auto desc = make_naive_tensor_descriptor( + make_tuple(G, K, C, Z, Y, X), + make_tuple(GStride, KStride, CStride, ZStride, YStride, XStride)); + const auto merged_desc = + transform_tensor_descriptor(desc, + make_tuple(make_merge_transform(make_tuple(G, K, Z, Y, X)), + make_pass_through_transform(C)), + make_tuple(Sequence<0, 1, 3, 4, 5>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + return device::PadTensorDescriptor( + merged_desc, make_tuple(MPerThread, NPerThread), Sequence{}); + } + + static auto TransposeInOutStrides(const std::array& g_n_c_wis_lengths, + const std::array& g_n_c_wis_strides) + { + if constexpr(device::is_NGCHW_NGKHW() || + device::is_NGCDHW_NGKDHW()) + { + std::array g_n_c_wis_strides_transposed; + const auto G = g_n_c_wis_lengths[I0]; + const auto C = g_n_c_wis_lengths[I2]; + + g_n_c_wis_strides_transposed[I0] = C; + g_n_c_wis_strides_transposed[I1] = g_n_c_wis_strides[I1]; + 
g_n_c_wis_strides_transposed[I2] = I1; + if constexpr(NDimSpatial == 2) + { + g_n_c_wis_strides_transposed[I3] = g_n_c_wis_lengths[I4] * G * C; + g_n_c_wis_strides_transposed[I4] = G * C; + } + else if constexpr(NDimSpatial == 3) + { + g_n_c_wis_strides_transposed[I3] = + g_n_c_wis_lengths[I4] * g_n_c_wis_lengths[I5] * G * C; + g_n_c_wis_strides_transposed[I4] = g_n_c_wis_lengths[I5] * G * C; + g_n_c_wis_strides_transposed[I5] = G * C; + } + return g_n_c_wis_strides_transposed; + } + else + { + // transpose not needed + return g_n_c_wis_strides; + } + } + + static auto + TransposeWeiStrides(const std::array& g_k_c_wis_lengths, + const std::array& g_k_c_wis_strides) + { + if constexpr(device::is_NGCHW_GKCYX_NGKHW() || + device::is_NGCDHW_GKCZYX_NGKDHW()) + { + std::array g_k_c_wis_strides_transposed = g_k_c_wis_strides; + const index_t C = g_k_c_wis_lengths[I2]; + + if constexpr(NDimSpatial == 2) + { + const index_t X = g_k_c_wis_lengths[I4]; + g_k_c_wis_strides_transposed[I2] = 1; + g_k_c_wis_strides_transposed[I3] = X * C; + g_k_c_wis_strides_transposed[I4] = C; + } + else if constexpr(NDimSpatial == 3) + { + const index_t Y = g_k_c_wis_lengths[I4]; + const index_t X = g_k_c_wis_lengths[I5]; + g_k_c_wis_strides_transposed[I2] = 1; + g_k_c_wis_strides_transposed[I3] = Y * X * C; + g_k_c_wis_strides_transposed[I4] = X * C; + g_k_c_wis_strides_transposed[I5] = C; + } + return g_k_c_wis_strides_transposed; + } + else + { + // transpose not needed + return g_k_c_wis_strides; + } + } +}; + +} // namespace tensor_operation +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/amd_address_space.hpp b/csrc/rocm/ck_extensions/include/ck/utility/amd_address_space.hpp new file mode 100755 index 000000000000..d54f70e750e8 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/amd_address_space.hpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "c_style_pointer_cast.hpp" + +// Address Space for AMDGCN +// https://llvm.org/docs/AMDGPUUsage.html#address-space + +namespace ck { + +enum struct AddressSpaceEnum +{ + Generic, + Global, + Lds, + Sgpr, + Vgpr, +}; + +template +__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p) +{ + // cast a pointer in "Constant" address space (4) to "Generic" address space (0) + // only c-style pointer cast seems be able to be compiled +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + return (T*)p; // NOLINT(old-style-cast) +#pragma clang diagnostic pop +} + +template +__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p) +{ + // cast a pointer in "Generic" address space (0) to "Constant" address space (4) + // only c-style pointer cast seems be able to be compiled +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" + return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast) +#pragma clang diagnostic pop +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/amd_buffer_addressing.hpp b/csrc/rocm/ck_extensions/include/ck/utility/amd_buffer_addressing.hpp new file mode 100755 index 000000000000..783fc661ceee --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/amd_buffer_addressing.hpp @@ -0,0 +1,1067 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
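An aside on the address-space helpers in amd_address_space.hpp above: AMDGCN distinguishes the Generic address space (0) from the Constant address space (4), and the buffer/scalar paths want the latter, so kernels round-trip ordinary pointers through the two C-style casts. A minimal usage sketch, not part of the patch (the function name and the float payload are invented for illustration):

// Illustration only: round-trip a generic pointer through the Constant
// address space using the two helpers defined above.
__device__ float read_through_constant_window(const float* p_global, ck::index_t i)
{
    // Generic (0) -> Constant (4).
    auto p_const = ck::cast_pointer_to_constant_address_space(p_global);
    // Constant (4) -> Generic (0): back to a plain pointer for a normal dereference.
    const float* p_generic = ck::cast_pointer_to_generic_address_space(p_const);
    return p_generic[i];
}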
+ +#pragma once +#include "data_type.hpp" + +namespace ck { + +template +union BufferResource +{ + __device__ constexpr BufferResource() : content{} {} + + // 128 bit SGPRs to supply buffer resource in buffer instructions + // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions + int32x4_t content; + StaticallyIndexedArray address; + StaticallyIndexedArray range; + StaticallyIndexedArray config; +}; + +template +__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size) +{ + BufferResource wave_buffer_resource; + + // wavewise base address (64 bit) + wave_buffer_resource.address(Number<0>{}) = const_cast*>(p_wave); + // wavewise range (32 bit) + wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T); + // wavewise setting (32 bit) + wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD; + + return wave_buffer_resource.content; +} + +template +__device__ int32x4_t make_wave_buffer_resource_with_default_range(T* p_wave) +{ + BufferResource wave_buffer_resource; + + // wavewise base address (64 bit) + wave_buffer_resource.address(Number<0>{}) = const_cast*>(p_wave); + // wavewise range (32 bit) + wave_buffer_resource.range(Number<2>{}) = 0xffffffff; // max possible range + // wavewise setting (32 bit) + wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD; + + return wave_buffer_resource.content; +} + +// buffer load i8 +__device__ int8_t +llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8"); + +__device__ int8x2_t +llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8"); + +__device__ int8x4_t +llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8"); + +// buffer load i16 +__device__ bhalf_t +llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16"); + +__device__ bhalf2_t +llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16"); + +__device__ bhalf4_t +llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16"); + +// buffer load i32 +__device__ int32_t +llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32"); + +__device__ int32x2_t +llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32"); + +__device__ int32x4_t +llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32"); + +// buffer load fp16 +__device__ half_t +llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16"); + +__device__ half2_t +llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16"); + +__device__ half4_t +llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc, + index_t voffset, + 
index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16"); + +// buffer load fp32 +__device__ float +llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32"); + +__device__ float2_t +llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32"); + +__device__ float4_t +llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32"); + +// buffer store i8 +__device__ void +llvm_amdgcn_raw_buffer_store_i8(int8_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8"); + +// buffer store i16 +__device__ void +llvm_amdgcn_raw_buffer_store_i16(bhalf_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i16x2(bhalf2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i16x4(bhalf4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16"); + +// buffer store i32 +__device__ void +llvm_amdgcn_raw_buffer_store_i32(int32_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32"); + +__device__ void +llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32"); + +// buffer store fp16 +__device__ void +llvm_amdgcn_raw_buffer_store_fp16(half_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16"); + +__device__ void +llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16"); + +// buffer store fp32 +__device__ void +llvm_amdgcn_raw_buffer_store_fp32(float vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32"); + +__device__ void +llvm_amdgcn_raw_buffer_store_fp32x2(float2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32"); + +__device__ void +llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) 
__asm("llvm.amdgcn.raw.buffer.store.v4f32"); + +// buffer atomic-add fp16 +__device__ half2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2( + half2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16"); + +// buffer atomic-add i32 +__device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( + int32_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32"); + +// buffer atomic-add fp32 +__device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32( + float vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32"); + +// buffer atomic-add fp32 +__device__ double +llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, + int32x4_t rsrc, // dst_wave_buffer_resource + int voffset, // dst_thread_addr_offset + int soffset, // dst_wave_addr_offset + int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); + +// memory coherency bit for buffer store/load instruction +// check ISA manual for each GFX target +// e.g. for +// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf, +// page 67~68 +enum struct AmdBufferCoherenceEnum +{ + DefaultCoherence = 0, // default value + GLC = 1, + SLC = 2, + GLC_SLC = 3, + // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1 + // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system + // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse + WAVE_NT0 = 0, + WAVE_NT1 = 2, + GROUP_NT0 = 1, + GROUP_NT1 = 3, + DEVICE_NT0 = 8, + DEVICE_NT1 = 10, + SYSTEM_NT0 = 9, + SYSTEM_NT1 = 11, +}; + +template +__device__ typename vector_type::type +amd_buffer_load_impl_raw(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) +{ + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + "wrong! 
not implemented"); + + if constexpr(N == 1) + { + return llvm_amdgcn_raw_buffer_load_i8(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + + int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } + else if constexpr(N == 4) + { + int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } + else if constexpr(N == 8) + { + int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } + else if constexpr(N == 16) + { + int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + return bit_cast(tmp); + } + else if constexpr(N == 32) + { + int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + int32x4_t tmp1 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + static_cast(coherence)); + vector_type tmp; + + tmp.AsType()(Number<0>{}) = tmp0; + tmp.AsType()(Number<1>{}) = tmp1; + + return bit_cast(tmp); + } + else if constexpr(N == 64) + { + int32x4_t tmp0 = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + int32x4_t tmp1 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + static_cast(coherence)); + int32x4_t tmp2 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(int32_t), + static_cast(coherence)); + int32x4_t tmp3 = + llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 12 * sizeof(int32_t), + static_cast(coherence)); + + vector_type tmp; + + tmp.AsType()(Number<0>{}) = tmp0; + tmp.AsType()(Number<1>{}) = tmp1; + tmp.AsType()(Number<2>{}) = tmp2; + tmp.AsType()(Number<3>{}) = tmp3; + + return bit_cast(tmp); + } +} + +template +__device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) +{ + static_assert( + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + "wrong! 
not implemented"); + + using r_t = typename vector_type::type; + auto raw_data = amd_buffer_load_impl_raw( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset); + return bit_cast(raw_data); +} + +template +__device__ void +amd_buffer_store_impl_raw(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + "wrong! not implemented"); + + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_i8(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + + llvm_amdgcn_raw_buffer_store_i16(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_i32(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 8) + { + llvm_amdgcn_raw_buffer_store_i32x2(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 16) + { + llvm_amdgcn_raw_buffer_store_i32x4(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 32) + { + vector_type tmp{bit_cast(src_thread_data)}; + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 4, + static_cast(coherence)); + } + else if constexpr(N == 64) + { + vector_type tmp{bit_cast(src_thread_data)}; + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 4, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 8, + static_cast(coherence)); + + llvm_amdgcn_raw_buffer_store_i32x4(tmp.template AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 12, + static_cast(coherence)); + } +} + +template +__device__ void amd_buffer_store_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert( + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + 
(is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + "wrong! not implemented"); + + using r_t = typename vector_type::type; + + amd_buffer_store_impl_raw(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset); +} + +template +__device__ void amd_global_atomic_add_impl(const typename vector_type::type src_thread_data, + T* addr) +{ + static_assert((is_same::value && (N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 2 || N == 4 || N == 8)), + "wrong! not implemented"); + + if constexpr(is_same::value) + { + vector_type tmp{src_thread_data}; + static_for<0, N / 2, 1>{}([&](auto i) { + __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast(addr) + i, + tmp.template AsType()[i]); + }); + } +#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__) + else if constexpr(is_same::value) + { + vector_type tmp{src_thread_data}; + static_for<0, N / 2, 1>{}([&](auto i) { + __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast(addr) + i, + tmp.template AsType()[i]); + }); + } +#endif +} + +template +__device__ void amd_buffer_atomic_add_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert((is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4)), + "wrong! not implemented"); + + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_add_fp32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(float), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(float), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(float), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(float), + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + static_for<0, 2, 1>{}([&](auto i) { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + i * sizeof(half2_t), + 0); + }); + } + else if constexpr(N == 8) + { + vector_type 
tmp{src_thread_data}; + + static_for<0, 4, 1>{}([&](auto i) { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + i * sizeof(half2_t), + 0); + }); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_add_i32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(int32_t), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(int32_t), + 0); + } + } +} + +template +__device__ void amd_buffer_atomic_max_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert((is_same::value && (N == 1 || N == 2 || N == 4)), + "wrong! not implemented"); + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_max_fp64(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(double), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(double), + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(double), + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(double), + 0); + } + } +} + +// buffer_load requires: +// 1) p_src_wave must point to global memory space +// 2) p_src_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. 
+template +__device__ typename vector_type_maker::type::type +amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, + index_t src_thread_element_offset, + bool src_thread_element_valid, + index_t src_element_space_size) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK + uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; + return amd_buffer_load_impl( + src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); + +#else + + vector_t tmp{amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0)}; + return src_thread_element_valid ? tmp : vector_t(0); +#endif +} + +// buffer_load requires: +// 1) p_src_wave must point to global memory space +// 2) p_src_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ typename vector_type_maker::type::type +amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, + index_t src_thread_element_offset, + bool src_thread_element_valid, + index_t src_element_space_size, + T customized_value) +{ + const int32x4_t src_wave_buffer_resource = + make_wave_buffer_resource(p_src_wave, src_element_space_size); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + + constexpr index_t vector_size = scalar_type::vector_size; + + vector_t tmp{amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0)}; + + return src_thread_element_valid ? tmp : vector_t(customized_value); +} + +// buffer_store requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ void amd_buffer_store(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + +// buffer_atomic_add requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. 
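The CK_EXPERIMENTAL_USE_BUFFER_*_OOB_CHECK_OFFSET_TRICK branches above avoid a per-lane branch by leaning on the buffer unit's hardware range check: invalid lanes have their byte offset pushed past the range programmed by make_wave_buffer_resource, so the hardware drops the store/atomic and returns zero for the load. A minimal sketch of that idea, not part of the patch (the helper name is invented):

// Sketch of the out-of-bounds offset trick used above: rather than branching on
// validity, shift invalid lanes to an offset the buffer range check must reject.
__device__ inline uint32_t oob_predicated_offset(uint32_t byte_offset, bool valid)
{
    // 0x80000000 is beyond any byte range make_wave_buffer_resource programs into
    // the descriptor, so the access becomes a no-op (stores/atomics) or returns
    // zero (loads) without diverging the wave.
    return valid ? byte_offset : byte_offset + 0x80000000u;
}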
+template +__device__ void +amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + + if constexpr(is_same::value) + { + if(dst_thread_element_valid) + { + amd_global_atomic_add_impl( + src_thread_data, p_dst_wave + dst_thread_element_offset); + } + } + else + { +#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; + + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif + } +} + +// buffer_atomic_max requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ void +amd_buffer_atomic_max(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; + + amd_buffer_atomic_max_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_atomic_max_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + +// Direct loads from global to LDS. +__device__ void +llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, + __attribute__((address_space(3))) uint32_t* lds_ptr, + index_t size, + index_t voffset, + index_t soffset, + index_t offset, + index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); + +#ifndef __HIPCC_RTC__ +template +__device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, + const index_t global_offset, + T* lds_base_ptr, + const index_t lds_offset, + const bool is_valid, + const index_t src_element_space_size) +{ + // Direct loads require that each thread reads and writes exactly a single DWORD. 
+ constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; +#if defined(__gfx950__) + constexpr auto dword_bytes = 4; + static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 || + bytes_per_thread == dword_bytes * 4); +#elif defined(__gfx942__) + constexpr auto dword_bytes = 4; + static_assert(bytes_per_thread == dword_bytes); +#endif + + const int32x4_t src_resource = + make_wave_buffer_resource(global_base_ptr, src_element_space_size); + const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000; + +#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM + T* lds_ptr = lds_base_ptr + lds_offset; +#ifndef CK_CODE_GEN_RTC + auto const lds_ptr_sgpr = + __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); +#else + auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); +#endif + asm volatile("s_mov_b32 m0, %0; \n\t" + "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), + "v"(global_offset_bytes), + "s"(src_resource) + : "memory"); +#else + // LDS pointer must be attributed with the LDS address space. + __attribute__((address_space(3))) uint32_t* lds_ptr = +#ifndef CK_CODE_GEN_RTC + reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( + reinterpret_cast(lds_base_ptr + lds_offset)); +#else + reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( + reinterpret_cast(lds_base_ptr + lds_offset)); +#endif + + llvm_amdgcn_raw_buffer_load_lds( + src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0); +#endif +} +#endif + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/amd_buffer_addressing_builtins.hpp b/csrc/rocm/ck_extensions/include/ck/utility/amd_buffer_addressing_builtins.hpp new file mode 100755 index 000000000000..f642e06050e5 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/amd_buffer_addressing_builtins.hpp @@ -0,0 +1,888 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
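The amd_direct_load_global_to_lds path that closes amd_buffer_addressing.hpp above moves data from global memory straight into LDS without staging it in VGPRs; the static_asserts limit each lane to exactly one DWORD on gfx942 (and 1, 3, or 4 DWORDs on gfx950). A hypothetical caller sketch, not part of the patch (the function name, tile size, and wave-size assumption are invented, and the wrapper is assumed to take the element type and per-thread element count as template arguments):

// Illustration only: one 64-lane wave stages a 64-element fp32 tile into LDS,
// one DWORD per lane, with out-of-range lanes dropped by the buffer range check.
__device__ void stage_tile_to_lds(const float* p_global,
                                  ck::index_t tile_start,
                                  ck::index_t n_valid_elems,
                                  ck::index_t global_space_size)
{
    __shared__ float lds_tile[64];
    const ck::index_t lane = static_cast<ck::index_t>(threadIdx.x) % 64;
    ck::amd_direct_load_global_to_lds<float, 1>(
        p_global,
        tile_start + lane,     // element offset into global memory
        lds_tile,
        lane,                  // element offset into LDS
        lane < n_valid_elems,  // predicate: invalid lanes are discarded
        global_space_size);    // element count backing the buffer resource
    __syncthreads();
}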
+ +#pragma once +#include "data_type.hpp" + +namespace ck { + +template +union BufferResource +{ + __device__ constexpr BufferResource() : content{} {} + + // 128 bit SGPRs to supply buffer resource in buffer instructions + // https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions + int32x4_t content; + StaticallyIndexedArray address; + StaticallyIndexedArray range; + StaticallyIndexedArray config; +}; + +template +__device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t element_space_size) +{ + BufferResource wave_buffer_resource; + + // wavewise base address (64 bit) + wave_buffer_resource.address(Number<0>{}) = const_cast*>(p_wave); + // wavewise range (32 bit) + wave_buffer_resource.range(Number<2>{}) = element_space_size * sizeof(T); + // wavewise setting (32 bit) + wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD; + + return wave_buffer_resource.content; +} + +template +__device__ int32x4_t make_wave_buffer_resource_with_default_range(T* p_wave) +{ + BufferResource wave_buffer_resource; + + // wavewise base address (64 bit) + wave_buffer_resource.address(Number<0>{}) = const_cast*>(p_wave); + // wavewise range (32 bit) + wave_buffer_resource.range(Number<2>{}) = 0xffffffff; // max possible range + // wavewise setting (32 bit) + wave_buffer_resource.config(Number<3>{}) = CK_BUFFER_RESOURCE_3RD_DWORD; + + return wave_buffer_resource.content; +} + +template +__device__ __amdgpu_buffer_rsrc_t make_wave_buffer_resource_new(T* p_wave, + index_t element_space_size) +{ + // wavewise base address (64 bit) + auto p = const_cast*>(p_wave); + int32_t stride = 0; + int32_t num = element_space_size * sizeof(T); + auto flags = CK_BUFFER_RESOURCE_3RD_DWORD; + + return __builtin_amdgcn_make_buffer_rsrc(p, stride, num, flags); +} + +template +__device__ __amdgpu_buffer_rsrc_t make_wave_buffer_resource_with_default_range_new(T* p_wave) +{ + // wavewise base address (64 bit) + auto p = const_cast*>(p_wave); + int32_t stride = 0; + int32_t num = 0xffffffff; + auto flags = CK_BUFFER_RESOURCE_3RD_DWORD; + + return __builtin_amdgcn_make_buffer_rsrc(p, stride, num, flags); +} + +// buffer atomic-add fp16 +__device__ half2_t llvm_amdgcn_raw_buffer_atomic_add_fp16x2( + half2_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.v2f16"); + +// buffer atomic-add i32 +__device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( + int32_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32"); + +// buffer atomic-add fp32 +__device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32( + float vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32"); + +// buffer atomic-add fp32 +__device__ double +llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, + int32x4_t rsrc, // dst_wave_buffer_resource + int voffset, // dst_thread_addr_offset + int soffset, // dst_wave_addr_offset + int glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fmax.f64"); + +// memory coherency bit for buffer store/load instruction +// check ISA manual for each GFX target +// e.g. 
for +// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf, +// page 67~68 +enum struct AmdBufferCoherenceEnum +{ + DefaultCoherence = 0, // default value + GLC = 1, + SLC = 2, + GLC_SLC = 3, + // gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1 + // SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system + // NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse + WAVE_NT0 = 0, + WAVE_NT1 = 2, + GROUP_NT0 = 1, + GROUP_NT1 = 3, + DEVICE_NT0 = 8, + DEVICE_NT1 = 10, + SYSTEM_NT0 = 9, + SYSTEM_NT1 = 11, +}; + +template +__device__ typename vector_type::type +amd_buffer_load_impl_raw(__amdgpu_buffer_rsrc_t src_wave_buffer_resource, + index_t src_thread_addr_offset, + index_t src_wave_addr_offset) +{ + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + "wrong! not implemented"); + + if constexpr(N == 1) + { + return __builtin_amdgcn_raw_buffer_load_b8(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + + int16_t tmp = __builtin_amdgcn_raw_buffer_load_b16(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } + else if constexpr(N == 4) + { + int32_t tmp = __builtin_amdgcn_raw_buffer_load_b32(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } + else if constexpr(N == 8) + { + int32x2_t tmp = __builtin_amdgcn_raw_buffer_load_b64(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + + return bit_cast(tmp); + } + else if constexpr(N == 16) + { + int32x4_t tmp = __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + return bit_cast(tmp); + } + else if constexpr(N == 32) + { + int32x4_t tmp0 = __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + int32x4_t tmp1 = + __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + static_cast(coherence)); + vector_type tmp; + + tmp.AsType()(Number<0>{}) = tmp0; + tmp.AsType()(Number<1>{}) = tmp1; + + return bit_cast(tmp); + } + else if constexpr(N == 64) + { + int32x4_t tmp0 = __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset, + static_cast(coherence)); + int32x4_t tmp1 = + __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(int32_t), + static_cast(coherence)); + int32x4_t tmp2 = + __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 8 * sizeof(int32_t), + static_cast(coherence)); + int32x4_t tmp3 = + __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 12 * sizeof(int32_t), + static_cast(coherence)); + + vector_type tmp; + + tmp.AsType()(Number<0>{}) = tmp0; + tmp.AsType()(Number<1>{}) = tmp1; + tmp.AsType()(Number<2>{}) = tmp2; + tmp.AsType()(Number<3>{}) = tmp3; + + return bit_cast(tmp); + } +} + +template +__device__ typename vector_type::type +amd_buffer_load_impl(__amdgpu_buffer_rsrc_t src_wave_buffer_resource, + index_t 
src_thread_addr_offset, + index_t src_wave_addr_offset) +{ + static_assert( + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + "wrong! not implemented"); + + using r_t = typename vector_type::type; + auto raw_data = amd_buffer_load_impl_raw( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset); + return bit_cast(raw_data); +} + +template +__device__ void +amd_buffer_store_impl_raw(const typename vector_type::type src_thread_data, + __amdgpu_buffer_rsrc_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32 || N == 64, + "wrong! not implemented"); + + if constexpr(N == 1) + { + __builtin_amdgcn_raw_buffer_store_b8(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 2) + { + + __builtin_amdgcn_raw_buffer_store_b16(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 4) + { + __builtin_amdgcn_raw_buffer_store_b32(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 8) + { + __builtin_amdgcn_raw_buffer_store_b64(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 16) + { + __builtin_amdgcn_raw_buffer_store_b128(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + } + else if constexpr(N == 32) + { + vector_type tmp{bit_cast(src_thread_data)}; + + __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 4, + static_cast(coherence)); + } + else if constexpr(N == 64) + { + vector_type tmp{bit_cast(src_thread_data)}; + + __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + static_cast(coherence)); + + __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 4, + static_cast(coherence)); + + __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 8, + 
static_cast(coherence)); + + __builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t) * 12, + static_cast(coherence)); + } +} + +template +__device__ void amd_buffer_store_impl(const typename vector_type::type src_thread_data, + __amdgpu_buffer_rsrc_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert( + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && + (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), + "wrong! not implemented"); + + using r_t = typename vector_type::type; + + amd_buffer_store_impl_raw(bit_cast(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset); +} + +template +__device__ void amd_global_atomic_add_impl(const typename vector_type::type src_thread_data, + T* addr) +{ + static_assert((is_same::value && (N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 2 || N == 4 || N == 8)), + "wrong! not implemented"); + + if constexpr(is_same::value) + { + vector_type tmp{src_thread_data}; + static_for<0, N / 2, 1>{}([&](auto i) { + __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast(addr) + i, + tmp.template AsType()[i]); + }); + } +#if defined(__gfx942__) || defined(__gfx950__) || defined(__gfx12__) + else if constexpr(is_same::value) + { + vector_type tmp{src_thread_data}; + static_for<0, N / 2, 1>{}([&](auto i) { + __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast(addr) + i, + tmp.template AsType()[i]); + }); + } +#endif +} + +template +__device__ void amd_buffer_atomic_add_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert((is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4)), + "wrong! 
not implemented"); + + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_add_fp32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(float), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(float), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(float), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(float), + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + static_for<0, 2, 1>{}([&](auto i) { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + i * sizeof(half2_t), + 0); + }); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + + static_for<0, 4, 1>{}([&](auto i) { + llvm_amdgcn_raw_buffer_atomic_add_fp16x2(tmp.AsType()[i], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + i * sizeof(half2_t), + 0); + }); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_add_i32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(int32_t), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(int32_t), + 0); + } + } +} + +template +__device__ void 
amd_buffer_atomic_max_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert((is_same::value && (N == 1 || N == 2 || N == 4)), + "wrong! not implemented"); + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_max_fp64(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(double), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(double), + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(double), + 0); + + llvm_amdgcn_raw_buffer_atomic_max_fp64(tmp.AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(double), + 0); + } + } +} + +// buffer_load requires: +// 1) p_src_wave must point to global memory space +// 2) p_src_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ typename vector_type_maker::type::type +amd_buffer_load_invalid_element_return_zero(const T* p_src_wave, + index_t src_thread_element_offset, + bool src_thread_element_valid, + index_t src_element_space_size) +{ + const __amdgpu_buffer_rsrc_t src_wave_buffer_resource = + make_wave_buffer_resource_new(p_src_wave, src_element_space_size); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK + uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000; + return amd_buffer_load_impl( + src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0); + +#else + + vector_t tmp{amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0)}; + return src_thread_element_valid ? tmp : vector_t(0); +#endif +} + +// buffer_load requires: +// 1) p_src_wave must point to global memory space +// 2) p_src_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. 
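+// Illustrative call pattern (a sketch, not part of this header): a caller typically derives a
+// per-thread element offset and a validity flag from its tensor coordinate and then invokes, e.g.,
+//   auto v = amd_buffer_load_invalid_element_return_zero<float, 4>(p_src_wave, offset, valid, space);
+// With CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK enabled, the invalid case is handled
+// without a branch: bit 31 of the byte offset is set, the access then lands past the buffer's
+// num_records, and the hardware range check returns zeros (this assumes the buffer is smaller
+// than 2 GiB).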
+template +__device__ typename vector_type_maker::type::type +amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, + index_t src_thread_element_offset, + bool src_thread_element_valid, + index_t src_element_space_size, + T customized_value) +{ + const __amdgpu_buffer_rsrc_t src_wave_buffer_resource = + make_wave_buffer_resource_new(p_src_wave, src_element_space_size); + + index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + + constexpr index_t vector_size = scalar_type::vector_size; + + vector_t tmp{amd_buffer_load_impl( + src_wave_buffer_resource, src_thread_addr_offset, 0)}; + + return src_thread_element_valid ? tmp : vector_t(customized_value); +} + +// buffer_store requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ void amd_buffer_store(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const __amdgpu_buffer_rsrc_t dst_wave_buffer_resource = + make_wave_buffer_resource_new(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_store_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + +// buffer_atomic_add requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ void +amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + + if constexpr(is_same::value) + { + if(dst_thread_element_valid) + { + amd_global_atomic_add_impl( + src_thread_data, p_dst_wave + dst_thread_element_offset); + } + } + else + { +#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; + + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif + } +} + +// buffer_atomic_max requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. 
+// It is user's responsibility to make sure that is true. +template +__device__ void +amd_buffer_atomic_max(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_MAX_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000; + + amd_buffer_atomic_max_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_atomic_max_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + +// Direct loads from global to LDS. +__device__ void +llvm_amdgcn_raw_buffer_load_lds(int32x4_t rsrc, + __attribute__((address_space(3))) uint32_t* lds_ptr, + index_t size, + index_t voffset, + index_t soffset, + index_t offset, + index_t aux) __asm("llvm.amdgcn.raw.buffer.load.lds"); + +#ifndef __HIPCC_RTC__ +template +__device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, + const index_t global_offset, + T* lds_base_ptr, + const index_t lds_offset, + const bool is_valid, + const index_t src_element_space_size) +{ + // Direct loads require that each thread reads and writes a multiple of DWORDs (4 bytes). + // For gfx950: supports 1, 3, or 4 DWORDs per thread + // For gfx942: supports exactly 1 DWORD per thread + constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; +#if defined(__gfx950__) + constexpr auto dword_bytes = 4; + static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 || + bytes_per_thread == dword_bytes * 4); +#elif defined(__gfx942__) + constexpr auto dword_bytes = 4; + static_assert(bytes_per_thread == dword_bytes); +#endif + + const int32x4_t src_resource = + make_wave_buffer_resource(global_base_ptr, src_element_space_size); + const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000; + +#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM + T* lds_ptr = lds_base_ptr + lds_offset; +#ifndef CK_CODE_GEN_RTC + auto const lds_ptr_sgpr = + __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); +#else + auto const lds_ptr_sgpr = __builtin_amdgcn_readfirstlane((reinterpret_cast(lds_ptr))); +#endif + asm volatile("s_mov_b32 m0, %0; \n\t" + "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr), + "v"(global_offset_bytes), + "s"(src_resource) + : "memory"); +#else + // LDS pointer must be attributed with the LDS address space. 
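+    // Descriptive note (added for clarity): this branch uses the llvm.amdgcn.raw.buffer.load.lds
+    // intrinsic, which copies global memory to LDS without staging the data in VGPRs; it deposits
+    // `bytes_per_thread` bytes at the LDS address built below. The inline-asm branch above achieves
+    // the same effect by placing the LDS base address in M0 and issuing `buffer_load_dword ... lds`.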
+ __attribute__((address_space(3))) uint32_t* lds_ptr = +#ifndef CK_CODE_GEN_RTC + reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( + reinterpret_cast(lds_base_ptr + lds_offset)); +#else + reinterpret_cast<__attribute__((address_space(3))) uint32_t*>( + reinterpret_cast(lds_base_ptr + lds_offset)); +#endif + + llvm_amdgcn_raw_buffer_load_lds( + src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0); +#endif +} +#endif + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/amd_ck_fp8.hpp b/csrc/rocm/ck_extensions/include/ck/utility/amd_ck_fp8.hpp new file mode 100755 index 000000000000..b7af32d3dc59 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/amd_ck_fp8.hpp @@ -0,0 +1,1743 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/enable_if.hpp" +#include "ck/utility/get_id.hpp" +#include "ck/utility/random_gen.hpp" +#include "ck/utility/functional.hpp" +#include "ck/utility/type.hpp" + +#ifndef CK_USE_FNUZ_FP8 +#define CK_USE_FNUZ_FP8 0 +#endif + +#ifndef CK_USE_OCP_FP8 +#define CK_USE_OCP_FP8 0 +#endif + +#if(defined(__gfx942__) || defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && \ + __HIP_DEVICE_COMPILE__ +#define CK_FP8_CVT_FAST_PATH 1 +#else +#define CK_FP8_CVT_FAST_PATH 0 +#endif + +#if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__ +#define CK_OCP_FP8_CVT_FAST_PATH 1 +#else +#define CK_OCP_FP8_CVT_FAST_PATH 0 +#endif + +namespace ck { + +using f8_fnuz_t = _BitInt(8); +using bf8_fnuz_t = unsigned _BitInt(8); + +typedef unsigned char fp8_storage_t; + +/** + * \brief Describes FP8 interpretation + */ +enum class ck_fp8_interpretation_t +{ + CK_E4M3_OCP = 0, // OCP E4M3 + CK_E5M2_OCP = 1, // OCP E5M2 + CK_E4M3_FNUZ = 2, // FP8 + CK_E5M2_FNUZ = 3, // BF8 +}; + +/** + * \brief Describes saturation behavior + */ +enum class ck_saturation_t +{ + CK_NOSAT = 0, // No saturation - replace with NaN or Inf + CK_SATFINITE = 1, // Saturate to finite +}; + +namespace fp8_impl { + +typedef fp8_storage_t fp8x2_storage_t __attribute__((ext_vector_type(2))); +typedef _Float16 half2_t __attribute__((ext_vector_type(2))); +typedef ushort ushortx2_t __attribute__((ext_vector_type(2))); +typedef short shortx2_t __attribute__((ext_vector_type(2))); +typedef float float2_t __attribute__((ext_vector_type(2))); + +__host__ __device__ static inline constexpr bool fnuz_f8_is_nan(f8_fnuz_t a) +{ + return static_cast(a) == 0x80; +} +__host__ __device__ static inline constexpr bool fnuz_bf8_is_nan(bf8_fnuz_t a) +{ + return static_cast(a) == 0x80; +} + +__host__ __device__ static inline constexpr bool ocp_f8_is_nan(fp8_storage_t a) +{ + return (a & 0x7f) == 0x7f; +} +__host__ __device__ static inline constexpr bool ocp_bf8_is_nan(fp8_storage_t a) +{ + return (a & 0x7f) > 0x7c; +} + +// The conversion function is from rocblas +// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L220 +// This has been modified to handle double types as well +template +__host__ __device__ static inline T cast_from_f8(fp8_storage_t x) +{ + constexpr bool is_half = __hip_internal::is_same::value; + constexpr bool is_float = __hip_internal::is_same::value; + constexpr bool is_double = __hip_internal::is_same::value; + static_assert(is_half || is_float || is_double, "only half, float and double are supported"); 
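+    // Note (added for clarity): weo / wmo below are the exponent / mantissa widths of the output
+    // type T (half: 5/10, float: 8/23, double: 11/52), while the we / wm template parameters
+    // describe the 8-bit input format being decoded.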
+ + constexpr int weo = is_half ? 5 : (is_float ? 8 : 11); + constexpr int wmo = is_half ? 10 : (is_float ? 23 : 52); + + T fInf, fNegInf, fNaN, fNeg0, fmax, fmin; + if constexpr(is_half) + { + const unsigned short int ihInf = 0x7C00; + const unsigned short int ihNegInf = 0xFC00; + const unsigned short int ihNaN = 0x7C01; + const unsigned short int ihNeg0 = 0x8000; + /* Max number in e5m2 57344*/ + const unsigned short int ifmax = 0x7B00; + const unsigned short int ifmin = 0xFB00; + + fInf = bit_cast<_Float16>(ihInf); + fNegInf = bit_cast<_Float16>(ihNegInf); + fNaN = bit_cast<_Float16>(ihNaN); + fNeg0 = bit_cast<_Float16>(ihNeg0); + fmax = bit_cast<_Float16>(ifmax); + fmin = bit_cast<_Float16>(ifmin); + } + else if constexpr(is_float) + { + const unsigned int ifInf = 0x7F800000; + const unsigned int ifNegInf = 0xFF800000; + const unsigned int ifNaN = 0x7F800001; + const unsigned int ifNeg0 = 0x80000000; + /* Max number in e5m2 57344*/ + const unsigned int ifmax = 0x47600000; + const unsigned int ifmin = 0xC7600000; + + fInf = bit_cast(ifInf); + fNegInf = bit_cast(ifNegInf); + fNaN = bit_cast(ifNaN); + fNeg0 = bit_cast(ifNeg0); + fmax = bit_cast(ifmax); + fmin = bit_cast(ifmin); + } + else if constexpr(is_double) + { + const unsigned long long ifInf = 0x7FF0000000000000ull; + const unsigned long long ifNegInf = 0xFFF0000000000000ull; + const unsigned long long ifNaN = 0x7FF0000000000001ull; + const unsigned long long ifNeg0 = 0x8000000000000000ull; + /* Max number in e5m2 57344*/ + const unsigned long long ifmax = 0x40EC000000000000ull; + const unsigned long long ifmin = 0xC0EC000000000000ull; + + fInf = bit_cast(ifInf); + fNegInf = bit_cast(ifNegInf); + fNaN = bit_cast(ifNaN); + fNeg0 = bit_cast(ifNeg0); + fmax = bit_cast(ifmax); + fmin = bit_cast(ifmin); + } + + if(x == 0) + { + return 0; + } + + unsigned long long sign = x >> 7; + unsigned long long mantissa = x & ((1 << wm) - 1); + int exponent = (x & 0x7F) >> wm; + if constexpr(is_fnuz) + { + if(x == 0x80) + { + return fNaN; + } + } + else + { + if(x == 0x80) + { + return fNeg0; + } + if constexpr(we == 4) + { // e4m3 + if((x & 0x7F) == 0x7F) + { + return fNaN; + } + } + else if((x & 0x7C) == 0x7C) + { // e5m2 + if((x & 0x3) == 0) + { + if constexpr(clip) + { + return sign ? fmin : fmax; + } + return sign ? fNegInf : fInf; + } + return fNaN; + } + } + + typename ck::conditional_t< + sizeof(T) == 2, + unsigned short int, + typename ck::conditional_t> + retval; + + if constexpr(we == 5 && is_half && !is_fnuz) + { + retval = x << 8; + return bit_cast(retval); + } + + const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 
1 : 0); + + // subnormal input + if(exponent == 0) + { +#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__ + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + __clz(mantissa) - (32 - wm); +#else + int sh = 1 + __builtin_clz(mantissa) - (32 - wm); +#endif + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1ull << wm) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if(exponent <= 0) + { + mantissa |= 1 << wmo; + mantissa >>= 1 - exponent; + exponent = 0; + } + + if constexpr(sizeof(T) == 2) + retval = (sign << 15) | (exponent << 10) | mantissa; + else if constexpr(sizeof(T) == 4) + retval = (sign << 31) | (exponent << 23) | mantissa; + else + retval = (sign << 63) | (static_cast(exponent) << 52) | mantissa; + + return bit_cast(retval); +} + +#if CK_FP8_CVT_FAST_PATH +template +static __host__ __device__ float cast_to_f32_from_f8(fp8_storage_t v) +{ + union + { + unsigned int i32val; + unsigned char i8val[4]; + } val; + val.i8val[0] = v; + + static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ || + interpret == ck_fp8_interpretation_t::CK_E4M3_OCP || + interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ || + interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, + "Only FNUZ and OCP interpretations are supported"); + + if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) || + (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)) + { + return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0); + } + else + { + return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0); + } +} + +template +static __device__ float2_t cast_to_f32_from_f8(fp8x2_storage_t v) +{ + const auto i16val = bit_cast(v); + + static_assert(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ || + interpret == ck_fp8_interpretation_t::CK_E4M3_OCP || + interpret == ck_fp8_interpretation_t::CK_E5M2_FNUZ || + interpret == ck_fp8_interpretation_t::CK_E5M2_OCP, + "Only FNUZ and OCP interpretations are supported"); + + if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) || + (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)) + { + return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, false); + } + else + { + return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, false); + } +} +#endif + +} // namespace fp8_impl + +struct f8_ocp_t +{ + using data_type = fp8_storage_t; + data_type data; + + static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE; + static constexpr ck_fp8_interpretation_t default_interpret = + ck_fp8_interpretation_t::CK_E4M3_OCP; + + static constexpr unsigned int we = 4; // exponent width + static constexpr unsigned int wm = 3; // mantissa width + + __host__ __device__ constexpr bool operator==(const f8_ocp_t& other) const + { + return (data == other.data) && (fp8_impl::ocp_f8_is_nan(data) == false); // NaN != NaN + } + +#if CK_USE_OCP_FP8 + __host__ __device__ explicit operator float() const +#else + __host__ explicit operator float() const +#endif + { +#if CK_OCP_FP8_CVT_FAST_PATH + return fp8_impl::cast_to_f32_from_f8(this->data); +#else + return fp8_impl::cast_from_f8( + this->data); // XXX: clip==false must be consistent with operator _Float16 +#endif + } + +#if CK_USE_OCP_FP8 + __host__ __device__ explicit operator _Float16() const +#else + __host__ explicit operator _Float16() const +#endif + { +#if CK_OCP_FP8_CVT_FAST_PATH + return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8(this->data)); +#else + return 
fp8_impl::cast_from_f8<_Float16, wm, we, false>( + this->data); // XXX: clip==false must be consistent with operator float +#endif + } +}; + +struct bf8_ocp_t +{ + using data_type = fp8_storage_t; + data_type data; + + static constexpr ck_saturation_t default_saturation = ck_saturation_t::CK_SATFINITE; + static constexpr ck_fp8_interpretation_t default_interpret = + ck_fp8_interpretation_t::CK_E5M2_OCP; + + static constexpr unsigned int we = 5; // exponent width + static constexpr unsigned int wm = 2; // mantissa width + + __host__ __device__ constexpr bool operator==(const bf8_ocp_t& other) const + { + return (data == other.data) && (fp8_impl::ocp_bf8_is_nan(data) == false); // NaN != NaN + } + +#if CK_USE_OCP_FP8 + __host__ __device__ explicit operator float() const + +#else + __host__ explicit operator float() const +#endif + { +#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) + return fp8_impl::cast_to_f32_from_f8(this->data); +#else + return fp8_impl::cast_from_f8( + this->data); // XXX: clip==false must be consistent with operator _Float16 +#endif + } + +#if CK_USE_OCP_FP8 + __host__ __device__ explicit operator _Float16() const +#else + __host__ explicit operator _Float16() const +#endif + { +#if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) + return static_cast<_Float16>(fp8_impl::cast_to_f32_from_f8(this->data)); +#else + return fp8_impl::cast_from_f8<_Float16, wm, we, false>( + this->data); // XXX: clip==false must be consistent with operator float +#endif + } +}; + +template +__host__ __device__ static inline constexpr bool fp8_is_nan(T); + +template <> +__host__ __device__ inline constexpr bool fp8_is_nan(f8_ocp_t a) +{ + return fp8_impl::ocp_f8_is_nan(a.data); +} +template <> +__host__ __device__ inline constexpr bool fp8_is_nan(bf8_ocp_t a) +{ + return fp8_impl::ocp_bf8_is_nan(a.data); +} +template <> +__host__ __device__ inline constexpr bool fp8_is_nan(f8_fnuz_t a) +{ + return fp8_impl::fnuz_f8_is_nan(a); +} +template <> +__host__ __device__ inline constexpr bool fp8_is_nan(bf8_fnuz_t a) +{ + return fp8_impl::fnuz_bf8_is_nan(a); +} + +template || is_same_v || + is_same_v || is_same_v, + bool> = true> +__host__ __device__ static inline constexpr bool fp8_is_inf(T) +{ + return false; +} +template <> +__host__ __device__ inline constexpr bool fp8_is_inf(bf8_ocp_t a) +{ + return (a.data & 0x7f) == 0x7c; +} + +namespace fp8_impl { + +// Assertions to check for supported conversion types +#define __fp8_impl_assert_ocp_support(interp) \ + { \ + if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP && \ + interp != ck_fp8_interpretation_t::CK_E5M2_OCP) \ + { \ + __hip_assert(false && "type is unsupported by current target device"); \ + } \ + } +#define __fp8_impl_assert_fnuz_support(interp) \ + { \ + if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ && \ + interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ) \ + { \ + __hip_assert(false && "type is unsupported by current target device"); \ + } \ + } + +__host__ __device__ static inline void +__is_interpret_supported([[maybe_unused]] ck_fp8_interpretation_t interp) +{ +#if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__ +#if CK_USE_OCP_FP8 + __fp8_impl_assert_ocp_support(interp); +#endif +#if CK_USE_FNUZ_FP8 + __fp8_impl_assert_fnuz_support(interp); +#endif +#endif +} + +#if defined(__gfx950__) +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0) +{ + union + { + unsigned int i32val; + half2_t half_vec; + 
fp8_storage_t i8val[4]; + } val; + + constexpr unsigned int i32val = 0; + val.half_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0); + } + } + + val.i32val = + __builtin_amdgcn_cvt_scalef32_sr_fp8_f16(i32val, val.half_vec[0], rng, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0) +{ + // there is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_f16(v[0], rng), + cast_to_f8_from_f16(v[1], rng)}; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0) +{ + union + { + unsigned int i32val; + half2_t half_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr unsigned int i32val = 0; + val.half_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0); + } + } + + val.i32val = + __builtin_amdgcn_cvt_scalef32_sr_bf8_f16(i32val, val.half_vec[0], rng, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0) +{ + // there is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_f16(v[0], rng), + cast_to_f8_from_f16(v[1], rng)}; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0) +{ + ignore = rng; + + union + { + unsigned int i32val; + half2_t half_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.half_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0); + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0) +{ +#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION + return fp8x2_storage_t{ + cast_to_f8_from_f16(v[0], rng), + cast_to_f8_from_f16(v[1], rng)}; +#else + ignore = rng; + + union + { + half2_t half_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.half_vec = v; + + if constexpr(saturate) + { + if((val.i16_vec[0] & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0); + } + if((val.i16_vec[1] & 0x7FFF) != 0x7FFF) + { + val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 448.0, -448.0); + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0); + + return fp8x2_storage_t{val.i8val[0], val.i8val[1]}; +#endif +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_f16(_Float16 v, unsigned int rng = 0) +{ + ignore = rng; + + union + { + unsigned int i32val; + half2_t half_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.half_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + 
{ + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0); + } + } + + val.half_vec = + __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_f16(half2_t v, unsigned int rng = 0) +{ +#if CK_WORKAROUND_FP16_TO_FP8_CONVERSION + return fp8x2_storage_t{ + cast_to_f8_from_f16(v[0], rng), + cast_to_f8_from_f16(v[1], rng)}; +#else + ignore = rng; + + union + { + half2_t half_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.half_vec = v; + + if constexpr(saturate) + { + if((val.i16_vec[0] & 0x7FFF) != 0x7FFF) + { + val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0); + } + if((val.i16_vec[1] & 0x7FFF) != 0x7FFF) + { + val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 57344.0, -57344.0); + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec, /* scale */ 1.f, 0); + + return fp8x2_storage_t{val.i8val[0], val.i8val[1]}; +#endif +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0) +{ + union + { + unsigned int i32val; + ushortx2_t bhalf_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr unsigned int i32val = 0; + val.bhalf_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = + ushort((bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >> + 16)); // convert to float and back + } + } + + val.i32val = __builtin_amdgcn_cvt_scalef32_sr_fp8_bf16( + i32val, val.bhalf_vec[0], rng, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0) +{ + // there is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_bf16(v[0], rng), + cast_to_f8_from_bf16(v[1], rng)}; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0) +{ + union + { + unsigned int i32val; + ushortx2_t bhalf_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr unsigned int i32val = 0; + val.bhalf_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = ushort( + (bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >> + 16)); // convert to float and back + } + } + + val.i32val = __builtin_amdgcn_cvt_scalef32_sr_bf8_bf16( + i32val, val.bhalf_vec[0], rng, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0) +{ + // there is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_bf16(v[0], rng), + cast_to_f8_from_bf16(v[1], rng)}; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0) +{ + ignore = rng; + + union + { + unsigned int i32val; + ushortx2_t bhalf_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.bhalf_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 
0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = + ushort((bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >> + 16)); // convert to float and back + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0) +{ +#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION + return fp8x2_storage_t{ + cast_to_f8_from_bf16(v[0], rng), + cast_to_f8_from_bf16(v[1], rng)}; +#else + ignore = rng; + + union + { + ushortx2_t bhalf_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.bhalf_vec = v; + + if constexpr(saturate) + { + if((val.i16_vec[0] & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = + ushort((bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >> + 16)); // convert to float and back + } + if((val.i16_vec[1] & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[1] = + ushort((bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[1]} << 16), 448.0, -448.0)) >> + 16)); // convert to float and back + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0); + + return fp8x2_storage_t{val.i8val[0], val.i8val[1]}; +#endif +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8_storage_t cast_to_f8_from_bf16(ushort v, unsigned int rng = 0) +{ + ignore = rng; + + union + { + unsigned int i32val; + ushortx2_t bhalf_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.bhalf_vec[0] = v; + + if constexpr(saturate) + { + if((val.i32val & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = ushort( + (bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >> + 16)); // convert to float and back + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0); + + return val.i8val[0]; +} + +template = false, + ck::enable_if_t = false> +static __device__ fp8x2_storage_t cast_to_f8_from_bf16(ushortx2_t v, unsigned int rng = 0) +{ + ignore = rng; + + union + { + ushortx2_t bhalf_vec; + shortx2_t i16_vec; + fp8_storage_t i8val[4]; + } val; + + constexpr shortx2_t i16x2val = {0, 0}; + val.bhalf_vec = v; + + if constexpr(saturate) + { + if((val.i16_vec[0] & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[0] = ushort( + (bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >> + 16)); // convert to float and back + } + if((val.i16_vec[1] & 0x7FFF) != 0x7FFF) + { + val.bhalf_vec[1] = ushort( + (bit_cast(__builtin_amdgcn_fmed3f( + bit_cast(uint32_t{val.bhalf_vec[1]} << 16), 57344.0, -57344.0)) >> + 16)); // convert to float and back + } + } + + val.i16_vec = + __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec, /* scale */ 1.f, 0); + + return fp8x2_storage_t{val.i8val[0], val.i8val[1]}; +} +#endif // defined(__gfx950__) + +#if CK_FP8_CVT_FAST_PATH +// The conversion function is from rocblas +// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_float8.h#L79 +template +static __device__ fp8_storage_t cast_to_f8_from_f32(float v, unsigned int rng = 0) +{ + fp8_storage_t i8data; + union + { + float fval; + unsigned int i32val; + unsigned char 
i8val[4]; // NOTE: not endian independent + } val; + + unsigned int ival = 0; + val.fval = v; + + if constexpr(saturate) + { + if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) + { + if((val.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); + } + } + else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) + { // OCP type + if((val.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0); + } + } + else + { + if((val.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0); + } + } + } + + if constexpr(stochastic_rounding) + { + ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) || + (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) + ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0) + : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos + val.i32val = ival; + i8data = val.i8val[0]; // little endian + } + else + { // RNE CVT + ival = (interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) || + (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) + ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false) + : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, + val.fval, + ival, + false); // false -> WORD0 + val.i32val = ival; + i8data = val.i8val[0]; + } + return i8data; +} + +template +static __device__ fp8x2_storage_t cast_to_f8_from_f32(float2_t v, unsigned int rng = 0) +{ + if constexpr(stochastic_rounding) + { + // there is no packed conversion with SR, so convert one element at a time + return fp8x2_storage_t{ + cast_to_f8_from_f32(v[0], rng), + cast_to_f8_from_f32(v[1], rng)}; + } + else + { + union + { + float fval; + unsigned int i32val; + unsigned char i8val[4]; + } val0, val1; + + val0.fval = v[0]; + val1.fval = v[1]; + + unsigned int ival = 0; + + if constexpr(saturate) + { + if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) + { + if((val0.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 240.0, -240.0); + } + if((val1.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 240.0, -240.0); + } + } + else if constexpr(interpret == ck_fp8_interpretation_t::CK_E4M3_OCP) + { // OCP type + if((val0.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 448.0, -448.0); + } + if((val1.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 448.0, -448.0); + } + } + else + { + if((val0.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 57344.0, -57344.0); + } + if((val1.i32val & 0x7F800000) != 0x7F800000) + { /// propagate NAN/INF, no clipping + val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 57344.0, -57344.0); + } + } + } + + // RNE CVT + if constexpr((interpret == ck_fp8_interpretation_t::CK_E4M3_FNUZ) || + (interpret == ck_fp8_interpretation_t::CK_E4M3_OCP)) + { + ival = __builtin_amdgcn_cvt_pk_fp8_f32(val0.fval, val1.fval, ival, false); + } + else + { + ival = __builtin_amdgcn_cvt_pk_bf8_f32(val0.fval, val1.fval, ival, false); + } + + val0.i32val = ival; + + return fp8x2_storage_t{val0.i8val[0], val0.i8val[1]}; + } +} 
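+// Usage sketch (illustrative only, not part of this header; the template-argument order
+// <interpretation, saturate, stochastic_rounding> is assumed here): converting a pair of floats
+// to packed OCP E4M3 with saturation and round-to-nearest-even:
+//   fp8x2_storage_t two =
+//       cast_to_f8_from_f32<ck_fp8_interpretation_t::CK_E4M3_OCP, true, false>(float2_t{x, y});
+// When stochastic rounding is requested there is no packed builtin, so the scalar
+// __builtin_amdgcn_cvt_sr_fp8_f32 path above is taken once per element with the caller's rng value.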
+#endif // CK_FP8_CVT_FAST_PATH + +// The conversion function is from rocblas +// https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39 +// This has been modified to add double types conversion as well +template +__host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rng = 0) +{ + constexpr bool is_half = __hip_internal::is_same::value; + constexpr bool is_float = __hip_internal::is_same::value; + constexpr bool is_double = __hip_internal::is_same::value; + static_assert(is_half || is_float || is_double, + "Only half, float and double can be cast to f8"); + + constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10); + + using T_bitwise = typename ck::conditional_t< + sizeof(T) == 2, + unsigned short int, + typename ck::conditional_t>; + T_bitwise x_bitwise = bit_cast(_x); + + unsigned long long x{x_bitwise}; + + unsigned long long head, mantissa; + int exponent, bias; + unsigned int sign; + unsigned long long fInf, mask; + + if constexpr(sizeof(T) == 8) + { + head = x & 0xFFF0000000000000ull; + mantissa = x & 0xFFFFFFFFFFFFFull; + exponent = (head >> 52) & 0x7FF; + sign = head >> 63; + bias = 1023; + fInf = 0x7FF0000000000000ull; + mask = 0x7FFFFFFFFFFFFFFFull; + } + else if constexpr(sizeof(T) == 4) + { + head = x & 0xFF800000; + mantissa = x & 0x7FFFFF; + exponent = (head >> 23) & 0xFF; + sign = head >> 31; + bias = 127; + fInf = 0x7F800000; + mask = 0x7FFFFFFF; + } + else + { + head = x & 0xFC00; + mantissa = x & 0x3FF; + exponent = (head >> 10) & 0x1F; + sign = head >> 15; + bias = 15; + fInf = 0x7C00; + mask = 0x7FFF; + } + unsigned int signed_inf = 0; + unsigned int nan = 0; + if constexpr(is_fnuz) + { + signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80; + nan = 0x80; + } + else + { + if constexpr(we == 4) + { // e4m3 + signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f); + } + else + { // e5m2 + signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c); + } + nan = (sign << 7) + 0x7f; + } + // Max values + unsigned long long ifmax = 0; + if constexpr(sizeof(T) == 8) + { + if constexpr(we == 5) + { // 57344 + ifmax = 0x40EC000000000000ull; + } + else + { + if constexpr(is_fnuz) + { // 240 + ifmax = 0x406E000000000000ull; + } + else + { // 448 + ifmax = 0x407C000000000000ull; + } + } + } + else if(sizeof(T) == 4) + { + if constexpr(we == 5) + { + ifmax = 0x47600000; + } + else + { + if constexpr(is_fnuz) + { + ifmax = 0x43700000; + } + else + { + ifmax = 0x43E00000; + } + } + } + else + { + if constexpr(we == 5) + { + ifmax = 0x7B00; + } + else + { + if constexpr(is_fnuz) + { + ifmax = 0x5B80; + } + else + { + ifmax = 0x5F00; + } + } + } + // Deal with inf and NaNs + if((x & fInf) == fInf) + { + if constexpr(is_fnuz) + return signed_inf; + + return mantissa != 0 ? nan : signed_inf; + } + + if((x & mask) > ifmax) + { + return signed_inf; + } + + if(x == 0) + { + return 0; + } + + // First need to check if it is normal or denorm as there is a difference of + // implicit 1 Then need to adjust the exponent to align with the F8 exponent, + // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng + // to mantissa and truncate. And for RNE, no need to add rng. Then probably + // need to check whether there is carry and adjust exponent and mantissa again + + // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent + // bits + const int f8_bias = (1 << (we - 1)) - 1 + (is_fnuz ? 
1 : 0); + const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // f8_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, f8_exponent, exponent_diff; + + if(exponent == 0) + { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we + mostly concern fp16 here. In this case, f8 is usually in denormal. But there + could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has + exponent bias 16. It means that there are some numbers in fp16 denormal but they + are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers + where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 + (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = f8_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } + else + { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if(act_exponent <= f8_denormal_act_exponent) + { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal + range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 + actual exponent is -7, it is actually larger due to the implicit 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = f8_denormal_act_exponent - act_exponent; + } + else + { // both fp32/fp16 and f8 are in normal range + exponent_diff = 0; // exponent_diff=0 does not mean there is no difference + // for this case, act_exponent could be larger. Just + // that it does not need shift mantissa + } + mantissa += (1ull << mfmt); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) == + (1ull << (mfmt - wm + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be + done before we shift right as shift right could rip off some residual part and + make something not midpoint look like midpoint. For example, the fp16 number + 0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right + by 4 bits, it would look like midpoint. + */ + + if(exponent_diff > 0) + mantissa >>= exponent_diff; + else if(exponent_diff == -1) + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1ull << mfmt); + // if there is no implicit 1, it means the f8 is denormal and need to adjust + // to denorm exponent + f8_exponent = + (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1; + bool odd = + mantissa & (1ull << (mfmt - wm)); // if the least significant bit that is not truncated is 1 + mantissa += + (stoch ? rng : (midpoint ? (odd ? 
mantissa : mantissa - 1ull) : mantissa)) & drop_mask; + + // Now we deal with overflow + if(f8_exponent == 0) + { + if((1ull << mfmt) & mantissa) + { + f8_exponent = 1; // denormal overflow to become normal, promote exponent + } + } + else + { + if((1ull << (mfmt + 1)) & mantissa) + { + mantissa >>= 1; + f8_exponent++; + } + } + + mantissa >>= (mfmt - wm); + + // above range: quantize to maximum possible float of the same sign + const int max_exp = (1 << we) - 1; + if(f8_exponent > max_exp) + { + if constexpr(clip) + { + mantissa = (1 << wm) - 1; + f8_exponent = max_exp; + } + else + { + return signed_inf; + } + } + + if(f8_exponent == 0 && mantissa == 0) + return is_fnuz ? 0 : (sign << 7); + mantissa &= (1 << wm) - 1; + return (sign << 7) | (f8_exponent << wm) | mantissa; +} + +/** + * \brief convert float to @p fp8_storage_t + * + * \tparam interp interpretation of fp8 + * \tparam sat saturation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param f float number + * \return fp8_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH +__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f) +{ + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&f), f); +#else + rng = prand_generator(reinterpret_cast(&f), f); +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) + } + return cast_to_f8_from_f32( + f, rng); +#else +#if CK_USE_OCP_FP8 +__host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f) +{ +#else +__host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) +{ +#endif + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&f), f); +#else + rng = prand_generator(reinterpret_cast(&f), f); +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) + } + + if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ) + { + return cast_to_f8(f, rng); + } + else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_FNUZ) + { + return cast_to_f8(f, rng); + } + else if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_OCP) + { + return cast_to_f8(f, rng); + } + else if constexpr(interp == ck_fp8_interpretation_t::CK_E5M2_OCP) + { + return cast_to_f8(f, rng); + } + else + { + __hip_assert(false && "FP8 type is not supported by current target device"); + return 0; + } +#endif // CK_FP8_CVT_FAST_PATH +} + +/** + * \brief convert vector of 2 floats to vector of 2 @p fp8_storage_t + * + * \tparam interp interpretation of fp8 + * \tparam sat saturation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param f vector of 2 floats + * \return fp8x2_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH +__device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f) +{ + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { +#if defined(__gfx950__) + // use HW clock for stochastic 
input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&f), f[0]); +#else + rng = prand_generator(reinterpret_cast(&f), f[0]); +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) + } + return cast_to_f8_from_f32( + f, rng); +#else +#if CK_USE_OCP_FP8 +__host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f) +{ +#else +__host__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f) +{ +#endif // CK_USE_OCP_FP8 + return fp8x2_storage_t{cvt_float_to_fp8(f[0]), + cvt_float_to_fp8(f[1])}; +#endif // CK_FP8_CVT_FAST_PATH +} + +/** + * \brief convert _Float16 to @p fp8_storage_t + * + * \tparam sat saturation of fp8 + * \tparam interp interpretation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param x _Float16 value + * \return fp8_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 +__host__ __device__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x) +#else +__host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x) +#endif +{ + { + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&x), x); +#else + rng = prand_generator(reinterpret_cast(&x), x); +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) + } +#if defined(__gfx950__) + return cast_to_f8_from_f16(x, rng); +#else + ignore = rng; + return cvt_float_to_fp8( + static_cast(x)); +#endif // defined(__gfx950__) + } +} + +/** + * \brief convert vector of 2 _Float16 to vector of 2 @p fp8_storage_t + * + * \tparam sat saturation of fp8 + * \tparam interp interpretation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param x vector of 2 _Float16 + * \return fp8x2_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 +__host__ __device__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x) +#else +__host__ static inline fp8x2_storage_t cvt_half_t_to_fp8(const half2_t x) +#endif +{ + { + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&x), x[0]); +#else + rng = prand_generator(reinterpret_cast(&x), x[0]); +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) + } +#if defined(__gfx950__) + return cast_to_f8_from_f16(x, rng); +#else + ignore = rng; + return cvt_float_to_fp8( + float2_t{static_cast(x[0]), static_cast(x[1])}); +#endif // defined(__gfx950__) + } +} + +/** + * \brief convert bhalf_t to @p fp8_storage_t + * + * \tparam sat saturation of fp8 + * \tparam interp interpretation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param x bhalf_t value + * \return fp8_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 
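+// (Note added for clarity: on targets without a native bf16 -> fp8 path, the bf16 value is widened
+//  to float by shifting its 16 bits into the upper half of a 32-bit word, since bf16 shares the
+//  upper 16 bits of binary32, and the float conversion routines above are then reused.)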
+__host__ __device__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x) +#else +__host__ static inline fp8_storage_t cvt_bhalf_t_to_fp8(const ushort x) +#endif +{ + { + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&x), + static_cast(x)); +#else + rng = prand_generator(reinterpret_cast(&x), static_cast(x)); +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) + } +#if defined(__gfx950__) + return cast_to_f8_from_bf16(x, rng); +#else + ignore = rng; + return cvt_float_to_fp8( + bit_cast(uint32_t{x} << 16)); // convert value to float +#endif // defined(__gfx950__) + } +} + +/** + * \brief convert vector of 2 bhalf_t to vector of 2 @p fp8_storage_t + * + * \tparam sat saturation of fp8 + * \tparam interp interpretation of fp8 + * \tparam stochastic_rounding switch between RNE and SR + * \param x vector of 2 bhalf_t + * \return fp8x2_storage_t + */ +template +#if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 +__host__ __device__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x) +#else +__host__ static inline fp8x2_storage_t cvt_bhalf_t_to_fp8(const ushortx2_t x) +#endif +{ +#if CK_WORKAROUND_BF16_TO_FP8_CONVERSION + return cvt_float_to_fp8( + float2_t{bit_cast(uint32_t{x[0]} << 16), + bit_cast(uint32_t{x[1]} << 16)}); // convert values to float +#else // CK_WORKAROUND_BF16_TO_FP8_CONVERSION + { + __is_interpret_supported(interp); + uint32_t rng = 0; + if constexpr(stochastic_rounding) + { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + rng = prand_generator(reinterpret_cast(&x), + static_cast(x[0])); +#else + rng = prand_generator(reinterpret_cast(&x), + static_cast(x[0])); +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) + } +#if defined(__gfx950__) + return cast_to_f8_from_bf16(x, rng); +#else + ignore = rng; + return cvt_float_to_fp8( + float2_t{bit_cast(uint32_t{x[0]} << 16), + bit_cast(uint32_t{x[1]} << 16)}); // convert values to float +#endif // defined(__gfx950__) + } +#endif // CK_WORKAROUND_BF16_TO_FP8_CONVERSION +} + +} // namespace fp8_impl + +#if CK_USE_OCP_FP8 +using f8_t = f8_ocp_t; +using bf8_t = bf8_ocp_t; +#define CK_FP8_TYPE_FNUZ 0 +#define CK_FP8_TYPE_OCP 1 +#else +using f8_t = f8_fnuz_t; +using bf8_t = bf8_fnuz_t; +#define CK_FP8_TYPE_FNUZ 1 +#define CK_FP8_TYPE_OCP 0 +#endif + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/amd_gemm_dpp.hpp b/csrc/rocm/ck_extensions/include/ck/utility/amd_gemm_dpp.hpp new file mode 100755 index 000000000000..a28292dade3d --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/amd_gemm_dpp.hpp @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
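+// (Descriptive summary, added for clarity: this header wraps the DPP8 inner products from
+//  inner_product_dpp8.hpp into a small per-lanegroup GEMM; DppLanegroupGemm accumulates each C
+//  element by pulling the shared A or B fragment from the lane selected by that element's index.)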
+ +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/inner_product_dpp8.hpp" + +namespace ck { + +namespace dpp8 { + +template +struct dpp_datatypes; + +template <> +struct dpp_datatypes +{ + // Dot product of `half2_t` and `half2_t` to get `float`. Reducing 2 elements from K in a + // single instruction. + using a_dtype = half_t; + using b_dtype = half_t; + using c_dtype = float; + static constexpr index_t k_per_instr = 2; +}; + +template +struct DppLanegroupGemm +{ + using datatypes_conf = dpp_datatypes; + using ADataType = typename datatypes_conf::a_dtype; + using BDataType = typename datatypes_conf::b_dtype; + using CDataType = typename datatypes_conf::c_dtype; + + __device__ void Run(const AVecDataType& a_vec, const BVecDataType& b_vec, CVecDataType& c_vec) + { + constexpr index_t num_c_elems_per_thread = ShareA ? MPerThread : NPerThread; + + const vector_type a_vector{a_vec}; + const vector_type b_vector{b_vec}; + + static_for<0, num_c_elems_per_thread, 1>{}([&](auto c_idx) { + float c = c_vec.template AsType()(c_idx); + // Next `c_idx` implies that we need to pull data from the next lane. + constexpr index_t source_lane = c_idx; + static_for<0, KPerThread / datatypes_conf::k_per_instr, 1>{}([&](auto k_chunk) { + const auto a_k_vec = a_vector.template AsType()[k_chunk]; + const auto b_k_vec = b_vector.template AsType()[k_chunk]; + ck::dpp8:: + inner_product_dpp( + a_k_vec, b_k_vec, c); + }); + c_vec.template AsType()(c_idx) = c; + }); + } +}; + +} // namespace dpp8 + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/amd_inline_asm.hpp b/csrc/rocm/ck_extensions/include/ck/utility/amd_inline_asm.hpp new file mode 100755 index 000000000000..0ed60df2c3ec --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/amd_inline_asm.hpp @@ -0,0 +1,435 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#ifndef CK_AMD_INLINE_ASM_HPP +#define CK_AMD_INLINE_ASM_HPP + +#include "c_style_pointer_cast.hpp" +#include "dtype_vector.hpp" + +// TODO: deprecate all amd_assembly_outer_product_xxx + +namespace ck { + +inline __device__ int amd_assembly_and_b32(int a, int b) +{ + int c; + asm volatile("v_and_b32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b)); + return c; +} + +inline __device__ int amd_assembly_and_or_b32(int a, int b, int d) +{ + int c; + asm volatile("v_and_or_b32 %0, %1, %2, %3" : "=v"(c) : "v"(a), "v"(b), "v"(d)); + return c; +} + +inline __device__ half2_t amd_assembly_pk_fma_f16(half2_t a, half2_t b, half2_t c) +{ + half2_t d; + asm volatile("v_pk_fma_f16 %0, %1, %2, %3" : "=v"(d) : "v"(a), "v"(b), "v"(c)); + return d; +} + +inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b) +{ + half2_t c; + asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b)); + return c; +} + +inline __device__ float amd_assemble_cvt_f32_i4(int b) +{ + float a; + asm volatile("v_cvt_off_f32_i4 %0, %1" : "=v"(a) : "v"(b)); + return a; +} + +inline __device__ f8x4_t amd_assembly_cvt_f8_to_f32(float b0, float b1, float b2, float b3) +{ + f8x4_t a; + asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2\n" + "v_cvt_pk_fp8_f32 %0, %3, %4, op_sel:[0, 0, 1]\n" + : "=v"(a) + : "v"(b0), "v"(b1), "v"(b2), "v"(b3)); + return a; +} + +inline __device__ f8x8_t amd_assembly_i4_to_fp8x8(int a) +{ + uint32_t i4x8 = static_cast(a); + uint32_t fp8x4_0; + uint32_t fp8x4_1; + float tmp_0, tmp_1, tmp_2; + + asm volatile("v_cvt_off_f32_i4 %[v_tmp_0], %[v_src]\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_2\n" + "v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1]\n" + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_src], src0_sel:BYTE_1\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_src], src0_sel:BYTE_3\n" + "v_cvt_pk_fp8_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1]\n" + "v_lshrrev_b32 %[v_tmp_2], 4, %[v_src]\n" + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2]\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_2\n" + "v_cvt_pk_fp8_f32 %[v_dst_0], %[v_tmp_0], %[v_tmp_1], op_sel:[0, 0, 1]\n" + "v_cvt_off_f32_i4 %[v_tmp_0], %[v_tmp_2], src0_sel:BYTE_1\n" + "v_cvt_off_f32_i4 %[v_tmp_1], %[v_tmp_2], src0_sel:BYTE_3\n" + "v_cvt_pk_fp8_f32 %[v_dst_1], %[v_tmp_0], %[v_tmp_1], op_sel:[0, 0, 1]\n" + : [v_tmp_0] "+v"(tmp_0), + [v_tmp_1] "+v"(tmp_1), + [v_tmp_2] "+v"(tmp_2), + [v_dst_0] "+v"(fp8x4_0), + [v_dst_1] "+v"(fp8x4_1), + [v_src] "+v"(i4x8) + :); + + return bit_cast(((static_cast(fp8x4_1) << 32) | fp8x4_0)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1) +{ + asm volatile("\n \ + v_fmac_f32 %0, %2, %3 \n \ + v_fmac_f32 %1, %2, %4 \n \ + " + : "=v"(c0), "=v"(c1) + : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) +__device__ void amd_assembly_outer_product_1x4( + float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3) +{ + asm volatile("\n \ + v_fmac_f32 %0, %4, %5 \n \ + v_fmac_f32 %1, %4, %6 \n \ + v_fmac_f32 %2, %4, %7 \n \ + v_fmac_f32 %3, %4, %8 \n \ + " + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +__device__ void +amd_assembly_outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& 
c1) +{ + asm volatile("\n \ + v_dot2_f32_f16 %0, %2, %3, %0\n \ + v_dot2_f32_f16 %1, %2, %4, %1\n \ + " + : "=v"(c0), "=v"(c1) + : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +__device__ void +amd_assembly_outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1) +{ + // TODO remove pointer casting + const half2_t* p_a_half2 = c_style_pointer_cast(&a); + const half2_t* p_b0_half2 = c_style_pointer_cast(&b0); + const half2_t* p_b1_half2 = c_style_pointer_cast(&b1); + + // do dot2 two times + asm volatile("\n \ + v_dot2_f32_f16 %0, %2, %4, %0\n \ + v_dot2_f32_f16 %1, %2, %6, %1\n \ + v_dot2_f32_f16 %0, %3, %5, %0\n \ + v_dot2_f32_f16 %1, %3, %7, %1\n \ + " + : "=v"(c0), "=v"(c1) + : "v"(p_a_half2[0]), + "v"(p_a_half2[1]), + "v"(p_b0_half2[0]), + "v"(p_b0_half2[1]), + "v"(p_b1_half2[0]), + "v"(p_b1_half2[1]), + "0"(c0), + "1"(c1)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) +__device__ void amd_assembly_outer_product_1x4(half2_t a, + half2_t b0, + half2_t b1, + half2_t b2, + half2_t b3, + float& c0, + float& c1, + float& c2, + float& c3) +{ + asm volatile("\n \ + v_dot2_f32_f16 %0, %4, %5, %0\n \ + v_dot2_f32_f16 %1, %4, %6, %1\n \ + v_dot2_f32_f16 %2, %4, %7, %2\n \ + v_dot2_f32_f16 %3, %4, %8, %3\n \ + " + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3)); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) +__device__ void amd_assembly_outer_product_1x4(half4_t a, + half4_t b0, + half4_t b1, + half4_t b2, + half4_t b3, + float& c0, + float& c1, + float& c2, + float& c3) +{ + // TODO remove pointer casting + const half2_t* p_a_half2 = c_style_pointer_cast(&a); + const half2_t* p_b0_half2 = c_style_pointer_cast(&b0); + const half2_t* p_b1_half2 = c_style_pointer_cast(&b1); + const half2_t* p_b2_half2 = c_style_pointer_cast(&b2); + const half2_t* p_b3_half2 = c_style_pointer_cast(&b3); + + // do dot2 two times + asm volatile("\n \ + v_dot2_f32_f16 %0, %4, %6, %0\n \ + v_dot2_f32_f16 %1, %4, %8, %1\n \ + v_dot2_f32_f16 %2, %4, %10, %2\n \ + v_dot2_f32_f16 %3, %4, %12, %3\n \ + v_dot2_f32_f16 %0, %5, %7, %0\n \ + v_dot2_f32_f16 %1, %5, %9, %1\n \ + v_dot2_f32_f16 %2, %5, %11, %2\n \ + v_dot2_f32_f16 %3, %5, %13, %3\n \ + " + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(p_a_half2[0]), + "v"(p_a_half2[1]), + "v"(p_b0_half2[0]), + "v"(p_b0_half2[1]), + "v"(p_b1_half2[0]), + "v"(p_b1_half2[1]), + "v"(p_b2_half2[0]), + "v"(p_b2_half2[1]), + "v"(p_b3_half2[0]), + "v"(p_b3_half2[1]), + "0"(c0), + "1"(c1), + "2"(c2), + "3"(c3)); +} + +__device__ void amd_assembly_outer_product_1x4(half8_t a, + half8_t b0, + half8_t b1, + half8_t b2, + half8_t b3, + float& c0, + float& c1, + float& c2, + float& c3) +{ + + // TODO remove pointer casting + const half4_t* p_a_half4 = c_style_pointer_cast(&a); + const half4_t* p_b0_half4 = c_style_pointer_cast(&b0); + const half4_t* p_b1_half4 = c_style_pointer_cast(&b1); + const half4_t* p_b2_half4 = c_style_pointer_cast(&b2); + const half4_t* p_b3_half4 = c_style_pointer_cast(&b3); + + amd_assembly_outer_product_1x4( + p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3); + + amd_assembly_outer_product_1x4( + p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3); +} + +__device__ 
void amd_assembly_outer_product_1x4(half16_t a, + half16_t b0, + half16_t b1, + half16_t b2, + half16_t b3, + float& c0, + float& c1, + float& c2, + float& c3) +{ + // TODO remove pointer casting + const half8_t* p_a_half8 = c_style_pointer_cast(&a); + const half8_t* p_b0_half8 = c_style_pointer_cast(&b0); + const half8_t* p_b1_half8 = c_style_pointer_cast(&b1); + const half8_t* p_b2_half8 = c_style_pointer_cast(&b2); + const half8_t* p_b3_half8 = c_style_pointer_cast(&b3); + + amd_assembly_outer_product_1x4( + p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3); + + amd_assembly_outer_product_1x4( + p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3); +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +__device__ void +amd_assembly_outer_product_1x2(int8x4_t a, int8x4_t b0, int8x4_t b1, int32_t& c0, int32_t& c1) +{ +#if 1 + asm volatile("\n \ + v_dot4_i32_i8 %0, %2, %3, %0\n \ + v_dot4_i32_i8 %1, %2, %4, %1\n \ + " + : "=v"(c0), "=v"(c1) + : "v"(bit_cast(a)), + "v"(bit_cast(b0)), + "v"(bit_cast(b1)), + "0"(c0), + "1"(c1)); +#else + c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); +#endif +} + +// c0 += inner_product(a, b0) +// c1 += inner_product(a, b1) +// c2 += inner_product(a, b2) +// c3 += inner_product(a, b3) +__device__ void amd_assembly_outer_product_1x4(int8x4_t a, + int8x4_t b0, + int8x4_t b1, + int8x4_t b2, + int8x4_t b3, + int32_t& c0, + int32_t& c1, + int32_t& c2, + int32_t& c3) +{ +#if 1 + asm volatile("\n \ + v_dot4_i32_i8 %0, %4, %5, %0\n \ + v_dot4_i32_i8 %1, %4, %6, %1\n \ + v_dot4_i32_i8 %2, %4, %7, %2\n \ + v_dot4_i32_i8 %3, %4, %8, %3\n \ + " + : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) + : "v"(bit_cast(a)), + "v"(bit_cast(b0)), + "v"(bit_cast(b1)), + "v"(bit_cast(b2)), + "v"(bit_cast(b3)), + "0"(c0), + "1"(c1), + "2"(c2), + "3"(c3)); +#else + c0 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b0), c0, false); + c1 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b1), c1, false); + c2 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b2), c2, false); + c3 = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b3), c3, false); +#endif +} + +__device__ void amd_assembly_outer_product_1x4(int8x8_t a, + int8x8_t b0, + int8x8_t b1, + int8x8_t b2, + int8x8_t b3, + int32_t& c0, + int32_t& c1, + int32_t& c2, + int32_t& c3) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I0], + vector_type{b0}.AsType()[I0], + vector_type{b1}.AsType()[I0], + vector_type{b2}.AsType()[I0], + vector_type{b3}.AsType()[I0], + c0, + c1, + c2, + c3); + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I1], + vector_type{b0}.AsType()[I1], + vector_type{b1}.AsType()[I1], + vector_type{b2}.AsType()[I1], + vector_type{b3}.AsType()[I1], + c0, + c1, + c2, + c3); +} + +__device__ void amd_assembly_outer_product_1x4(int8x16_t a, + int8x16_t b0, + int8x16_t b1, + int8x16_t b2, + int8x16_t b3, + int32_t& c0, + int32_t& c1, + int32_t& c2, + int32_t& c3) + +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I0], + vector_type{b0}.AsType()[I0], + vector_type{b1}.AsType()[I0], + vector_type{b2}.AsType()[I0], + vector_type{b3}.AsType()[I0], + c0, + c1, + c2, + c3); + + 
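+    // Note: the 16-element int8 operands are consumed as four int8x4_t slices
+    // (I0..I3); each slice goes through the int8x4_t overload above, i.e. four
+    // v_dot4_i32_i8 accumulations per slice into the same c0..c3 accumulators.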
amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I1], + vector_type{b0}.AsType()[I1], + vector_type{b1}.AsType()[I1], + vector_type{b2}.AsType()[I1], + vector_type{b3}.AsType()[I1], + c0, + c1, + c2, + c3); + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I2], + vector_type{b0}.AsType()[I2], + vector_type{b1}.AsType()[I2], + vector_type{b2}.AsType()[I2], + vector_type{b3}.AsType()[I2], + c0, + c1, + c2, + c3); + + amd_assembly_outer_product_1x4(vector_type{a}.AsType()[I3], + vector_type{b0}.AsType()[I3], + vector_type{b1}.AsType()[I3], + vector_type{b2}.AsType()[I3], + vector_type{b3}.AsType()[I3], + c0, + c1, + c2, + c3); +} + +} // namespace ck +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/utility/amd_lds.hpp b/csrc/rocm/ck_extensions/include/ck/utility/amd_lds.hpp new file mode 100755 index 000000000000..c218fded9678 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/amd_lds.hpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/amd_address_space.hpp" +#include "ck/utility/dynamic_buffer.hpp" +#include "ck/utility/math.hpp" + +namespace ck { + +namespace lds_utils { + +/** \brief Allocate a given number of buffers in LDS and return them as a tuple. + * + * \tparam DataType Data type of elements to be stored in LDS. + * \tparam NumBuffers Number of buffers to be allocated. + * \param lds_ptr Address of the beginning of LDS space. + * \param num_elems_per_buffer Number of elements to allocate per single buffer. + * \param start_offset_elems Number of elements to move from the start of LDS for the allocation of + * the first buffer. \param lds_alignment Alignment of every buffer allocation given as a number of + * elements. \return Tuple of dynamic buffers representing memory allocated in LDS. + */ +template +__device__ static auto AllocateLdsBuffers(void* lds_ptr, + int32_t num_elems_per_buffer, + int32_t start_offset_elems, + int32_t lds_alignment) +{ + const DataType* lds_start = static_cast(lds_ptr) + start_offset_elems; + const int32_t single_buffer_offset = + math::integer_least_multiple(num_elems_per_buffer, lds_alignment); + return generate_tuple( + [&](auto i) { + const int32_t local_offset = i * single_buffer_offset; + return make_dynamic_buffer(lds_start + local_offset, + num_elems_per_buffer); + }, + Number{}); +} + +} // namespace lds_utils +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/amd_smfmac.hpp b/csrc/rocm/ck_extensions/include/ck/utility/amd_smfmac.hpp new file mode 100755 index 000000000000..8b6b094ff23a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/amd_smfmac.hpp @@ -0,0 +1,99 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. 
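+
+// Thin wrappers around the gfx94x sparse-MFMA (smfmac) builtins. The A operand
+// is structured-sparse and therefore carries half as many elements as B
+// (e.g. half4_t vs. half8_t); reg_idx packs the 8-bit sparsity index sets and
+// the abid template parameter selects which set the instruction reads (see the
+// CBSZ/ABID note below). Minimal call sketch, illustrative only; Acc is assumed
+// to be a vector_type wrapper exposing AsType<float4_t>(), as used elsewhere in
+// this header, and the abid template argument must be supplied explicitly if it
+// has no default:
+//
+//   Acc acc{};
+//   intrin_smfmac_f32_16x16x32f16<16, 16>::Run(a /*half4_t*/, b /*half8_t*/, idx, acc);
+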
+ +#include "ck/ck.hpp" +#pragma once + +namespace ck { + +template +struct intrin_smfmac_f32_16x16x32f16; + +// for every smfmac instruction if CBSZ[1:0]=0, ABID[1:0] selects one of four 8-bit sets of sparse +// indices from reg_idx +template <> +struct intrin_smfmac_f32_16x16x32f16<16, 16> +{ + template + __device__ static void + Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, abid); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif + } +}; + +template +struct intrin_smfmac_f32_16x16x32bf16; + +template <> +struct intrin_smfmac_f32_16x16x32bf16<16, 16> +{ + template + __device__ static void + Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, abid); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif + } +}; + +template +struct intrin_smfmac_f32_32x32x16f16; + +template <> +struct intrin_smfmac_f32_32x32x16f16<32, 32> +{ + template + __device__ static void + Run(const half4_t& reg_a, const half8_t& reg_b, const index_t& reg_idx, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, abid); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif + } +}; + +template +struct intrin_smfmac_f32_32x32x16bf16; + +template <> +struct intrin_smfmac_f32_32x32x16bf16<32, 32> +{ + template + __device__ static void + Run(const bhalf4_t& reg_a, const bhalf8_t& reg_b, const index_t& reg_idx, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], reg_idx, 0, abid); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; + ignore = reg_idx; +#endif + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/amd_wave_read_first_lane.hpp b/csrc/rocm/ck_extensions/include/ck/utility/amd_wave_read_first_lane.hpp new file mode 100755 index 000000000000..3604712837de --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/amd_wave_read_first_lane.hpp @@ -0,0 +1,161 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
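+
+// Helpers that broadcast a value from the first active lane to the whole wave
+// via __builtin_amdgcn_readfirstlane (forcing it into an SGPR). Besides the
+// 32-bit overloads, a generic overload handles any trivially copyable object by
+// copying it in SGPR-sized (4-byte) chunks, with the get_carrier<1|2|3> types
+// covering a tail that is not a multiple of 4 bytes. Illustrative sketch using
+// a hypothetical struct:
+//
+//   struct Tile { int32_t row; int32_t col; };            // trivially copyable
+//   Tile uniform_tile = amd_wave_read_first_lane(tile);   // wave-uniform copy
+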
+ +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/functional2.hpp" +#include "ck/utility/math.hpp" + +#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) +#include +#include +#include +#include +#endif + +namespace ck { +namespace detail { + +template +struct get_carrier; + +template <> +struct get_carrier<1> +{ + using type = uint8_t; +}; + +template <> +struct get_carrier<2> +{ + using type = uint16_t; +}; + +template <> +struct get_carrier<3> +{ + using type = class carrier + { + using value_type = uint32_t; + + Array bytes; + static_assert(sizeof(bytes) <= sizeof(value_type)); + + // replacement of host std::copy_n() + template + __device__ static OutputIterator copy_n(InputIterator from, Size size, OutputIterator to) + { + if(0 < size) + { + *to = *from; + ++to; + for(Size count = 1; count < size; ++count) + { + *to = *++from; + ++to; + } + } + + return to; + } + + // method to trigger template substitution failure + __device__ carrier(const carrier& other) noexcept + { + copy_n(other.bytes.begin(), bytes.Size(), bytes.begin()); + } + + public: + __device__ carrier& operator=(value_type value) noexcept + { + copy_n(reinterpret_cast(&value), bytes.Size(), bytes.begin()); + + return *this; + } + + __device__ operator value_type() const noexcept + { + ck::byte result[sizeof(value_type)]; + + copy_n(bytes.begin(), bytes.Size(), result); + + return *reinterpret_cast(result); + } + }; +}; +static_assert(sizeof(get_carrier<3>::type) == 3); + +template <> +struct get_carrier<4> +{ + using type = uint32_t; +}; + +template +using get_carrier_t = typename get_carrier::type; + +} // namespace detail + +__device__ inline uint32_t amd_wave_read_first_lane(uint32_t value) +{ + return __builtin_amdgcn_readfirstlane(value); +} + +__device__ inline int32_t amd_wave_read_first_lane(int32_t value) +{ + return __builtin_amdgcn_readfirstlane(value); +} + +__device__ inline int64_t amd_wave_read_first_lane(int64_t value) +{ + constexpr unsigned object_size = sizeof(int64_t); + constexpr unsigned second_part_offset = object_size / 2; + auto* const from_obj = reinterpret_cast(&value); + alignas(int64_t) ck::byte to_obj[object_size]; + + using Sgpr = uint32_t; + + *reinterpret_cast(to_obj) = + amd_wave_read_first_lane(*reinterpret_cast(from_obj)); + *reinterpret_cast(to_obj + second_part_offset) = + amd_wave_read_first_lane(*reinterpret_cast(from_obj + second_part_offset)); + + return *reinterpret_cast(to_obj); +} + +template && ck::is_trivially_copyable_v>> +__device__ auto amd_wave_read_first_lane(const Object& obj) +{ + using Size = unsigned; + constexpr Size SgprSize = 4; + constexpr Size ObjectSize = sizeof(Object); + + auto* const from_obj = reinterpret_cast(&obj); + alignas(Object) ck::byte to_obj[ObjectSize]; + + constexpr Size RemainedSize = ObjectSize % SgprSize; + constexpr Size CompleteSgprCopyBoundary = ObjectSize - RemainedSize; + for(Size offset = 0; offset < CompleteSgprCopyBoundary; offset += SgprSize) + { + using Sgpr = detail::get_carrier_t; + + *reinterpret_cast(to_obj + offset) = + amd_wave_read_first_lane(*reinterpret_cast(from_obj + offset)); + } + + if constexpr(0 < RemainedSize) + { + using Carrier = detail::get_carrier_t; + + *reinterpret_cast(to_obj + CompleteSgprCopyBoundary) = amd_wave_read_first_lane( + *reinterpret_cast(from_obj + CompleteSgprCopyBoundary)); + } + + /// NOTE: Implicitly start object lifetime. 
It's better to use std::start_lifetime_at() in this + /// scenario + return *reinterpret_cast(to_obj); +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/amd_wmma.hpp b/csrc/rocm/ck_extensions/include/ck/utility/amd_wmma.hpp new file mode 100755 index 000000000000..e14c0d62a8f7 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/amd_wmma.hpp @@ -0,0 +1,441 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_AMD_WMMA_HPP +#define CK_AMD_WMMA_HPP + +#include "ck/utility/amd_inline_asm.hpp" +#include "data_type.hpp" +// TODO: Add arch limitation +namespace ck { + +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \ + defined(__gfx1103__) || defined(__gfx11_generic__) +#define __gfx11__ +#endif + +#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__) +#define __gfx12__ +#endif + +/********************************WAVE32 MODE***********************************************/ + +// src: fp16, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_f16_w32; + +template <> +struct intrin_wmma_f32_16x16x16_f16_w32<16, 16> +{ + template + __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) + { + // * Inline assembly need to elimate the duplicated data load, compiler won't help you + // delete them. + // amd_assembly_wmma_f32_16x16x16_f16_w32( + // reg_a, reg_b, reg_c.template AsType()(Number<0>{})); +#if defined(__gfx11__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf16, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf16_w32; + +template <> +struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16> +{ + template + __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx11__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: fp16, dst: fp16 +template +struct intrin_wmma_f16_16x16x16_f16_w32; + +template +struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel> +{ + template + __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) + { + // opsel usage + // false: D0.[0:15] = result + // true : D0.[16:31]= result +#if defined(__gfx11__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf16, dst: bf16 +template +struct intrin_wmma_bf16_16x16x16_bf16_w32; + +template +struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel> +{ + template + __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) + { + // opsel usage + // false: D0.[0:15] = result + // true : D0.[16:31]= result +#if defined(__gfx11__) || defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: iu8, dst: i32 +template +struct intrin_wmma_i32_16x16x16_iu8_w32; + +template +struct 
intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp> +{ + template + __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx11__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32( + neg_a, + bit_cast(reg_a), + neg_b, + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + clamp); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +/********************************WAVE64 MODE***********************************************/ + +template +struct intrin_wmma_f32_16x16x16_f16_w64; + +template <> +struct intrin_wmma_f32_16x16x16_f16_w64<16, 16> +{ + template + __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx11__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf16, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf16_w64; + +template <> +struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16> +{ + template + __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx11__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: fp16, dst: fp16 +template +struct intrin_wmma_f16_16x16x16_f16_w64; + +template +struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel> +{ + template + __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c) + { + // opsel usage + // false: D0.[0:15] = result + // true : D0.[16:31]= result +#if defined(__gfx11__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf16, dst: bf16 +template +struct intrin_wmma_bf16_16x16x16_bf16_w64; + +template +struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel> +{ + template + __device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c) + { + // opsel usage + // false: D0.[0:15] = result + // true : D0.[16:31]= result +#if defined(__gfx11__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], Opsel); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: iu8, dst: i32 +template +struct intrin_wmma_i32_16x16x16_iu8_w64; + +template +struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> +{ + template + __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx11__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_i32_16x16x16_iu8_w64( + neg_a, + bit_cast(reg_a), + neg_b, + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + clamp); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// gfx12 +/********************************WAVE32 MODE***********************************************/ + +// src: fp16, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_f16_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_f16_w32_gfx12<16, 16> +{ + template + __device__ 
static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c) + { + // * Inline assembly need to elimate the duplicated data load, compiler won't help you + // delete them. + // amd_assembly_wmma_f32_16x16x16_f16_w32( + // reg_a, reg_b, reg_c.template AsType()(Number<0>{})); +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf16, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_bf16_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: iu8, dst: i32 +template +struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12; + +template +struct intrin_wmma_i32_16x16x16_iu8_w32_gfx12<16, 16, neg_a, neg_b, clamp> +{ + template + __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12( + neg_a, + bit_cast(reg_a), + neg_b, + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + clamp); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: f8, f8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_f8f8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_f8f8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: f8, bf8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf8, f8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +// src: bf8, bf8, dst: fp32 +template +struct intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12; + +template <> +struct intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx12__) 
+ reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}]); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +} // namespace ck +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/utility/amd_xdlops.hpp b/csrc/rocm/ck_extensions/include/ck/utility/amd_xdlops.hpp new file mode 100755 index 000000000000..02a7a72b8c92 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/amd_xdlops.hpp @@ -0,0 +1,1639 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "ck/utility/dtype_fp64.hpp" + +namespace ck { +// Define the common macro for MI300 models +#if defined(__gfx942__) || defined(__gfx950__) +#define __gfx94__ +#endif + +// fp32 +template +struct intrin_mfma_f32_32x32x1f32; + +template <> +struct intrin_mfma_f32_32x32x1f32<64, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); + } +}; + +template <> +struct intrin_mfma_f32_32x32x1f32<32, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + } +}; + +template +struct intrin_mfma_f32_32x32x2f32; + +template <> +struct intrin_mfma_f32_32x32x2f32<32, 32> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x4f32; + +template <> +struct intrin_mfma_f32_16x16x4f32<16, 16> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x1f32; + +template <> +struct intrin_mfma_f32_16x16x1f32<16, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); + } +}; + +template +struct intrin_mfma_f32_4x4x1f32; + +template <> +struct intrin_mfma_f32_4x4x1f32<4, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + } +}; + +template <> +struct intrin_mfma_f32_4x4x1f32<8, 64> +{ + template + __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 
0); + } +}; + +// fp16 +template +struct intrin_mfma_f32_32x32x4f16; + +template <> +struct intrin_mfma_f32_32x32x4f16<64, 64> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); + } +}; + +template <> +struct intrin_mfma_f32_32x32x4f16<32, 64> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + } +}; + +template +struct intrin_mfma_f32_32x32x16f16; + +template <> +struct intrin_mfma_f32_32x32x16f16<32, 32> +{ + template + __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_f32_16x16x32f16; + +template <> +struct intrin_mfma_f32_16x16x32f16<16, 16> +{ + template + __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_f32_32x32x8f16; + +template <> +struct intrin_mfma_f32_32x32x8f16<32, 32> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x16f16; + +template <> +struct intrin_mfma_f32_16x16x16f16<16, 16> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x4f16; + +template <> +struct intrin_mfma_f32_16x16x4f16<16, 64> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); + } +}; + +template +struct intrin_mfma_f32_4x4x4f16; + +template <> +struct intrin_mfma_f32_4x4x4f16<4, 64> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + } +}; + +template <> +struct intrin_mfma_f32_4x4x4f16<8, 64> +{ + template + __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 
0, 0); + reg_c.template AsType()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0); + } +}; + +// bfp16 +template +struct intrin_mfma_f32_32x32x16bf16; + +template <> +struct intrin_mfma_f32_32x32x16bf16<32, 32> +{ + template + __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_f32_16x16x32bf16; + +template <> +struct intrin_mfma_f32_16x16x32bf16<16, 16> +{ + template + __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_f32_32x32x8bf16_1k; + +template <> +struct intrin_mfma_f32_32x32x8bf16_1k<32, 32> +{ + template + __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x16bf16_1k; + +template <> +struct intrin_mfma_f32_16x16x16bf16_1k<16, 16> +{ + template + __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_32x32x4bf16; + +template <> +struct intrin_mfma_f32_32x32x4bf16<32, 32> +{ + template + __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_f32_16x16x8bf16; + +template <> +struct intrin_mfma_f32_16x16x8bf16<16, 16> +{ + template + __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); + } +}; + +template +struct intrin_mfma_i32_32x32x8i8; + +template <> +struct intrin_mfma_i32_32x32x8i8<32, 32> +{ + template + __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_i32_16x16x16i8; + +template <> +struct intrin_mfma_i32_16x16x16i8<16, 16> +{ + template + __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_i32_32x32x32i8; + +template <> +struct intrin_mfma_i32_32x32x32i8<32, 32> +{ + template + __device__ 
static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_i32_16x16x64i8; + +template <> +struct intrin_mfma_i32_16x16x64i8<16, 16> +{ + template + __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif // defined(__gfx950__) + } +}; + +template +struct intrin_mfma_i32_32x32x16i8; + +template <> +struct intrin_mfma_i32_32x32x16i8<32, 32> +{ + template + __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_i32_16x16x32i8; + +template <> +struct intrin_mfma_i32_16x16x32i8<16, 16> +{ + template + __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c) + { + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); + } +}; + +template +struct intrin_mfma_f64_16x16x4f64; + +template <> +struct intrin_mfma_f64_16x16x4f64<16, 16> +{ + template + __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c) + { +#if defined(__gfx90a__) || defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_f32_32x32x64f8f6f4; + +/// @brief Performs a matrix fused multiply-accumulate operation on 32x32x64 submatrices for f8, f6, +/// and f4 data types. +/// +/// @note Calls scaled version of the instruction as the original instruction is not supported in +/// the backend. That is the intended use. There is a backend optimization to select the unscaled +/// operation if the scale is 0. 
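+///
+/// @note The cbsz/blgp immediates encode the element format of the A and B
+/// operands respectively (0 FP8 E4M3, 1 BF8 E5M2, 2 FP6 E2M3, 3 BF6 E3M2,
+/// 4 FP4 E2M1); the overloads below pick them to match the argument types and
+/// zero-pad the narrower f6/f4 operands into the int32x8_t layout the builtin
+/// expects.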
+template <> +struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> +{ + template + __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x4_t arg_a = bit_cast(reg_a); + int32x4_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + reg_c.template AsType()[Number<0>{}], + 4, // cbsz + 4, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) 
= + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_scale_f32_32x32x64f8f6f4; + +template +struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32, OpselA, OpselB> +{ + template + __device__ static void Run(const f8x32_t& reg_a, + const int32_t& scale_a, + const f8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + OpselA, // OPSEL + scale_a, + OpselB, // OPSEL + scale_b); + // XXX: Note on the scale_a and scale_b parameters: + // If compiler detects that one or both scales are constant values, it will treat that + // constant as F32 constant. I.e., if scale_a at some point was declared as + // `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is + // assigned value `bit_cast(static_cast(a_scale))`. + + // XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even + // when OPSEL is set otherwise. +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf8x32_t& reg_a, + const int32_t& scale_a, + const bf8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + OpselA, // OPSEL + scale_a, + OpselB, // OPSEL + scale_b); + // XXX: Note on the scale_a and scale_b parameters: + // If compiler detects that one or both scales are constant values, it will treat that + // constant as F32 constant. I.e., if scale_a at some point was declared as + // `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is + // assigned value `bit_cast(static_cast(a_scale))`. + + // XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even + // when OPSEL is set otherwise. 
+#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf8x32_t& reg_a, + const int32_t& scale_a, + const f8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + OpselA, // OPSEL + scale_a, + OpselB, // OPSEL + scale_b); + // XXX: Note on the scale_a and scale_b parameters: + // If compiler detects that one or both scales are constant values, it will treat that + // constant as F32 constant. I.e., if scale_a at some point was declared as + // `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is + // assigned value `bit_cast(static_cast(a_scale))`. + + // XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even + // when OPSEL is set otherwise. +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f6x32_t& reg_a, + const int32_t scale_a, + const f6x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + OpselA, // OPSEL + scale_a, + OpselB, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf6x32_t& reg_a, + const int32_t scale_a, + const bf6x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + OpselA, // OPSEL + scale_a, + OpselB, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f4x32_t& reg_a, + const int32_t scale_a, + const f4x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x4_t arg_a = bit_cast(reg_a); + int32x4_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + 
reg_c.template AsType()[Number<0>{}], + 4, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 4, // blgp + OpselA, // OPSEL + scale_a, + OpselB, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_scale_f32_16x16x128f8f6f4; + +template +struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB> +{ + template + __device__ static void Run(const f8x32_t& reg_a, + const int32_t& scale_a, + const f8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + OpselA, // OPSEL + scale_a, + OpselB, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf8x32_t& reg_a, + const int32_t& scale_a, + const bf8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + OpselA, // OPSEL + scale_a, + OpselB, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f8x32_t& reg_a, + const int32_t& scale_a, + const bf8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + OpselA, // OPSEL + scale_a, + OpselB, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf8x32_t& reg_a, + const int32_t& scale_a, + const f8x32_t& reg_b, + const int32_t& scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + OpselA, // OPSEL + scale_a, + OpselB, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f6x32_t& reg_a, + const int32_t scale_a, + const f6x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + 
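+        // The f6 operands occupy six dwords; repack them into the eight-dword
+        // int32x8_t argument the builtin expects (upper two dwords zeroed) and
+        // pass cbsz = blgp = 2 to mark both operands as FP6 E2M3.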
int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + OpselA, // OPSEL + scale_a, + OpselB, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f6x16x2_t& reg_a, + const int32_t scale_a, + const f6x16x2_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + using arg_type = int32x8_t; + arg_type arg_a{ + static_cast(reg_a.template AsType()[Number<0>{}][0]), + static_cast(reg_a.template AsType()[Number<0>{}][1]), + static_cast(reg_a.template AsType()[Number<0>{}][2]), + static_cast(reg_a.template AsType()[Number<1>{}][0]), + static_cast(reg_a.template AsType()[Number<1>{}][1]), + static_cast(reg_a.template AsType()[Number<1>{}][2]), + 0, + 0}; + arg_type arg_b{ + static_cast(reg_b.template AsType()[Number<0>{}][0]), + static_cast(reg_b.template AsType()[Number<0>{}][1]), + static_cast(reg_b.template AsType()[Number<0>{}][2]), + static_cast(reg_b.template AsType()[Number<1>{}][0]), + static_cast(reg_b.template AsType()[Number<1>{}][1]), + static_cast(reg_b.template AsType()[Number<1>{}][2]), + 0, + 0}; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_a, + arg_b, + reg_c.template AsType()[Number<0>{}], + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + OpselA, // OPSEL + scale_a, + OpselB, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf6x32_t& reg_a, + const int32_t scale_a, + const bf6x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + OpselA, // OPSEL + scale_a, + OpselB, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf6x16x2_t& reg_a, + const int32_t scale_a, + const bf6x16x2_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + using arg_type = int32x8_t; + arg_type arg_a{ + static_cast(reg_a.template AsType()[Number<0>{}][0]), + static_cast(reg_a.template AsType()[Number<0>{}][1]), + static_cast(reg_a.template AsType()[Number<0>{}][2]), + static_cast(reg_a.template AsType()[Number<1>{}][0]), + static_cast(reg_a.template AsType()[Number<1>{}][1]), + static_cast(reg_a.template AsType()[Number<1>{}][2]), + 0, + 0}; + arg_type arg_b{ + static_cast(reg_b.template AsType()[Number<0>{}][0]), + static_cast(reg_b.template 
AsType()[Number<0>{}][1]), + static_cast(reg_b.template AsType()[Number<0>{}][2]), + static_cast(reg_b.template AsType()[Number<1>{}][0]), + static_cast(reg_b.template AsType()[Number<1>{}][1]), + static_cast(reg_b.template AsType()[Number<1>{}][2]), + 0, + 0}; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_a, + arg_b, + reg_c.template AsType()[Number<0>{}], + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + OpselA, // OPSEL + scale_a, + OpselB, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f4x32_t& reg_a, + const int32_t scale_a, + const f4x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + int32x4_t arg_a = bit_cast(reg_a); + int32x4_t arg_b = bit_cast(reg_b); + using arg_type = int32x8_t; + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + reg_c.template AsType()[Number<0>{}], + 4, // cbsz + 4, // blgp + OpselA, // OPSEL + scale_a, + OpselB, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x128f8f6f4; + +/// @brief Performs a matrix fused multiply-accumulate operation on 16x16x128 submatrices for f8f6f4 +/// data types. +/// +/// @note Calls scaled version of the instruction as the original instruction is not supported in +/// the backend. That is the intended use. There is a backend optimization to select the unscaled +/// operation if the scale is 0. 
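+///
+/// Minimal usage sketch (illustrative only; the fragment and accumulator types
+/// follow the conventions of the callers of these wrappers, with 32 packed
+/// 8-bit elements of A/B per lane and a 4 x f32 accumulator fragment):
+///
+///   vector_type<float, 4> acc;   // per-lane C fragment
+///   f8x32_t a_frag;              // 32 packed FP8 (E4M3) values of A
+///   f8x32_t b_frag;              // 32 packed FP8 (E4M3) values of B
+///   intrin_mfma_f32_16x16x128f8f6f4<16, 16>::Run(a_frag, b_frag, acc);
+///
+/// With cbsz/blgp = 0 and both scale operands passed as 0, the compiler can
+/// lower this to the unscaled MFMA as described above.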
+template <> +struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> +{ + template + __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 0, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10 + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + reg_a, + reg_b, + reg_c.template AsType()[Number<0>{}], + 0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 1, // blgp + 0, + 0, + 0, + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + int32x4_t arg_a = bit_cast(reg_a); + int32x4_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + reg_c.template AsType()[Number<0>{}], + 4, // cbsz + 4, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], 
arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 2, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } + + template + __device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + int32x6_t arg_a = bit_cast(reg_a); + int32x6_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0}, + reg_c.template AsType()[Number<0>{}], + 3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1} + 3, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_f32_32x32x16f8f8; + +template <> +struct intrin_mfma_f32_32x32x16f8f8<32, 32> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32f8f8; + +template <> +struct intrin_mfma_f32_16x16x32f8f8<16, 16> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_32x32x16bf8bf8; + +template <> +struct intrin_mfma_f32_32x32x16bf8bf8<32, 32> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32bf8bf8; + +template <> +struct intrin_mfma_f32_16x16x32bf8bf8<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template 
AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_32x32x16f8bf8; + +template <> +struct intrin_mfma_f32_32x32x16f8bf8<32, 32> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32f8bf8; + +template <> +struct intrin_mfma_f32_16x16x32f8bf8<16, 16> +{ + template + __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_32x32x16bf8f8; + +template <> +struct intrin_mfma_f32_32x32x16bf8f8<32, 32> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c); + }); +#endif + } +}; + +template +struct intrin_mfma_f32_16x16x32bf8f8; + +template <> +struct intrin_mfma_f32_16x16x32bf8f8<16, 16> +{ + template + __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8( + bit_cast(reg_a), + bit_cast(reg_b), + reg_c.template AsType()[Number<0>{}], + 0, + 0, + 0); +#else + vector_type reg_a_v(reg_a); + vector_type reg_b_v(reg_b); + + static_for<0, 8, 1>{}([&](auto k) { + float reg_a_f32 = type_convert(reg_a_v.template AsType()[Number{}]); + float reg_b_f32 = type_convert(reg_b_v.template AsType()[Number{}]); + + intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c); + }); 
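+        // Emulation path for targets without __gfx94__ FP8/BF8 MFMA support:
+        // the loop above converts each of the 8 packed elements to f32 and
+        // accumulates them through the 16x16x4 f32 MFMA instead.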
+#endif + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/array.hpp b/csrc/rocm/ck_extensions/include/ck/utility/array.hpp new file mode 100755 index 000000000000..2afad00d497f --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/array.hpp @@ -0,0 +1,70 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_ARRAY_HPP +#define CK_ARRAY_HPP + +#include "functional2.hpp" +#include "sequence.hpp" + +namespace ck { + +template +struct Array +{ + using type = Array; + using data_type = TData; + + TData mData[NSize]; + + __host__ __device__ static constexpr index_t Size() { return NSize; } + + __host__ __device__ constexpr const TData& At(index_t i) const { return mData[i]; } + + __host__ __device__ constexpr TData& At(index_t i) { return mData[i]; } + + __host__ __device__ constexpr const TData& operator[](index_t i) const { return At(i); } + + __host__ __device__ constexpr TData& operator()(index_t i) { return At(i); } + + template + __host__ __device__ constexpr auto operator=(const T& a) + { + static_assert(T::Size() == Size(), "wrong! size not the same"); + + static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); + + return *this; + } + __host__ __device__ constexpr const TData* begin() const { return &mData[0]; } + __host__ __device__ constexpr const TData* end() const { return &mData[NSize]; } + __host__ __device__ constexpr TData* begin() { return &mData[0]; } + __host__ __device__ constexpr TData* end() { return &mData[NSize]; } +}; + +// empty Array +template +struct Array +{ + using type = Array; + using data_type = TData; + + __host__ __device__ static constexpr index_t Size() { return 0; } +}; + +template +__host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs) +{ + using data_type = remove_cvref_t; + return Array{ck::forward(x), ck::forward(xs)...}; +} + +// make empty array +template +__host__ __device__ constexpr auto make_array() +{ + return Array{}; +} + +} // namespace ck +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/utility/array_multi_index.hpp b/csrc/rocm/ck_extensions/include/ck/utility/array_multi_index.hpp new file mode 100755 index 000000000000..c0c1ea65fc77 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/array_multi_index.hpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_ARRAY_MULTI_INDEX_HPP +#define CK_ARRAY_MULTI_INDEX_HPP + +#include "common_header.hpp" + +namespace ck { + +template +using MultiIndex = Array; + +template +__host__ __device__ constexpr auto make_multi_index(Xs&&... xs) +{ + return make_array(index_t{xs}...); +} + +template +__host__ __device__ constexpr auto make_zero_multi_index() +{ + return unpack([](auto... xs) { return make_multi_index(xs...); }, + typename uniform_sequence_gen::type{}); +} + +template +__host__ __device__ constexpr auto to_multi_index(const T& x) +{ + return unpack([](auto... ys) { return make_multi_index(ys...); }, x); +} + +template +__host__ __device__ constexpr auto operator+=(MultiIndex& y, const X& x) +{ + static_assert(X::Size() == NSize, "wrong! size not the same"); + static_for<0, NSize, 1>{}([&](auto i) { y(i) += x[i]; }); + return y; +} + +template +__host__ __device__ constexpr auto operator-=(MultiIndex& y, const X& x) +{ + static_assert(X::Size() == NSize, "wrong! 
size not the same"); + static_for<0, NSize, 1>{}([&](auto i) { y(i) -= x[i]; }); + return y; +} + +template +__host__ __device__ constexpr auto operator+(const MultiIndex& a, const T& b) +{ + using type = MultiIndex; + static_assert(T::Size() == NSize, "wrong! size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] + b[i]; }); + return r; +} + +template +__host__ __device__ constexpr auto operator-(const MultiIndex& a, const T& b) +{ + using type = MultiIndex; + static_assert(T::Size() == NSize, "wrong! size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] - b[i]; }); + return r; +} + +template +__host__ __device__ constexpr auto operator*(const MultiIndex& a, const T& b) +{ + using type = MultiIndex; + static_assert(T::Size() == NSize, "wrong! size not the same"); + type r; + static_for<0, NSize, 1>{}([&](auto i) { r(i) = a[i] * b[i]; }); + return r; +} + +} // namespace ck +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/utility/blkgemmpipe_scheduler.hpp b/csrc/rocm/ck_extensions/include/ck/utility/blkgemmpipe_scheduler.hpp new file mode 100755 index 000000000000..861b81b1f673 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/blkgemmpipe_scheduler.hpp @@ -0,0 +1,149 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +enum struct BlockGemmPipelineVersion +{ + // For GEMM + v1, // Naive + v2, // Mem + v3, // Comp + v4, // Comp, double lds buffer + v5, // Comp, double global prefetch register buffer + + // For GEMM with preshuffled weight + // v1, single lds buffer + // v2, double lds buffer +}; +enum struct BlockGemmPipelineScheduler +{ + Intrawave, + Interwave, +}; + +enum struct TailNumber +{ + // Single / Double buffer pipeline + Odd, + Even, + + // Long prefetch pipeline, up to 8 + One, + Two, + Three, + Four, + Five, + Six, + Seven, + + // Unroll stages > Prefetch stages, number of loop is multiple of unroll stages + Empty, + // Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add + // prefetchstages + Full, +}; + +enum SchedulerGroup : uint32_t +{ + SCHED_GROUP_MFMA = 0x008, // Matrix FMA instructions + SCHED_GROUP_VMEM = 0x020, // Global memory operations + SCHED_GROUP_LDS_READ = 0x100, // LDS read operations + SCHED_GROUP_LDS_WRITE = 0x200 // LDS write operations +}; + +template +struct BlockwiseGemmXdlops_pipeline_hotloop_inst +{ + static constexpr index_t WaveSize = 64; + static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL); + + static constexpr index_t A_LDS_Read_Width = ALDSReadWidth; + static constexpr index_t B_LDS_Read_Width = BLDSReadWidth; + + static constexpr index_t A_Buffer_Load_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth); + static constexpr index_t B_Buffer_Load_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth); + + static constexpr index_t A_LDS_Write_Inst_Num = + MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth); + static constexpr index_t B_LDS_Write_Inst_Num = + NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth); + + static constexpr index_t A_LDS_Read_Inst_Num = + WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth); + static constexpr index_t B_LDS_Read_Inst_Num = + WaveNumM * NPerBlock * KPerBlock / (BlockSize * 
BLDSReadWidth); + + static constexpr index_t C_MFMA_Inst_Num = + MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL); + + static constexpr index_t C_MFMA_SpeedUp = IsF4F6 ? 2 : 1; + + static constexpr index_t C_MFMA_Inst_Cycle = []() { + if constexpr(NPerXDL == 16) + { + return KPerXDL == 128 ? 32 / C_MFMA_SpeedUp : 16 / C_MFMA_SpeedUp; + } + else if constexpr(NPerXDL == 32) + { + return KPerXDL == 64 ? 64 / C_MFMA_SpeedUp : 32 / C_MFMA_SpeedUp; + } + }(); + + static constexpr auto Print() + { + printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n", + BlockSize, + WaveSize, + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + KPerXDL); + + printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " + "%d, %d\n C MFMA inst: %d C MFMA cycle: %d\n" + "A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: " + "%d/ %d\n", + A_Buffer_Load_Inst_Num, + B_Buffer_Load_Inst_Num, + A_LDS_Write_Inst_Num, + B_LDS_Write_Inst_Num, + A_LDS_Read_Inst_Num, + B_LDS_Read_Inst_Num, + C_MFMA_Inst_Num, + C_MFMA_Inst_Cycle, + A_LDS_Read_Width, + B_LDS_Read_Width, + ALDSWriteWidth, + BLDSWriteWidth, + ABufferLoadWidth, + BBufferLoadWidth); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/c_style_pointer_cast.hpp b/csrc/rocm/ck_extensions/include/ck/utility/c_style_pointer_cast.hpp new file mode 100755 index 000000000000..610e393a7721 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/c_style_pointer_cast.hpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_C_STYLE_POINTER_CAST_HPP +#define CK_C_STYLE_POINTER_CAST_HPP + +#include "type.hpp" +#include "enable_if.hpp" + +namespace ck { + +template && is_pointer_v, bool>::type = false> +__host__ __device__ PY c_style_pointer_cast(PX p_x) +{ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wold-style-cast" +#pragma clang diagnostic ignored "-Wcast-align" + return (PY)p_x; // NOLINT(old-style-cast, cast-align) +#pragma clang diagnostic pop +} + +} // namespace ck +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/utility/common_header.hpp b/csrc/rocm/ck_extensions/include/ck/utility/common_header.hpp new file mode 100755 index 000000000000..c2c3aa002c13 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/common_header.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
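+//
+// Aggregate utility header for the ck:: namespace: it pulls in the container,
+// math, data-type and device helpers used by the kernels in this extension.
+// The buffer-addressing implementation is selected further down based on
+// __clang_major__.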
+ +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/array.hpp" +#include "ck/utility/container_helper.hpp" +#include "ck/utility/statically_indexed_array.hpp" +#include "ck/utility/container_element_picker.hpp" +#include "ck/utility/multi_index.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/functional.hpp" +#include "ck/utility/functional2.hpp" +#include "ck/utility/functional3.hpp" +#include "ck/utility/functional4.hpp" +#include "ck/utility/enable_if.hpp" +#include "ck/utility/ignore.hpp" +#include "ck/utility/integral_constant.hpp" +#include "ck/utility/math.hpp" +#include "ck/utility/number.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/sequence_helper.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/tuple_helper.hpp" +#include "ck/utility/type.hpp" +#include "ck/utility/type_convert.hpp" +#include "ck/utility/magic_division.hpp" +#include "ck/utility/c_style_pointer_cast.hpp" +#include "ck/utility/is_known_at_compile_time.hpp" +#include "ck/utility/transpose_vectors.hpp" +#include "ck/utility/inner_product.hpp" +#include "ck/utility/thread_group.hpp" +#include "ck/utility/debug.hpp" + +#if __clang_major__ == 20 +#include "amd_buffer_addressing_builtins.hpp" +#else +#include "amd_buffer_addressing.hpp" +#endif +#include "ck/utility/amd_wave_read_first_lane.hpp" +#include "ck/utility/generic_memory_space_atomic.hpp" +#include "ck/utility/get_id.hpp" +#include "ck/utility/thread_group.hpp" +#include "ck/utility/synchronization.hpp" +#include "ck/utility/amd_address_space.hpp" +#include "ck/utility/static_buffer.hpp" +#include "ck/utility/dynamic_buffer.hpp" + +// TODO: remove this +#if CK_USE_AMD_INLINE_ASM +#include "ck/utility/amd_inline_asm.hpp" +#endif + +#ifdef CK_USE_AMD_MFMA +#include "ck/utility/amd_xdlops.hpp" +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/utility/container_element_picker.hpp b/csrc/rocm/ck_extensions/include/ck/utility/container_element_picker.hpp new file mode 100755 index 000000000000..838147e420cc --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/container_element_picker.hpp @@ -0,0 +1,158 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_CONTAINER_ELEMENT_PICKER_HPP +#define CK_CONTAINER_ELEMENT_PICKER_HPP + +#include "functional2.hpp" +#include "sequence.hpp" + +namespace ck { + +// Arr: Array or StaticallyIndexedArray +// Picks: Sequence<...> +template +struct ContainerElementPicker +{ + using type = ContainerElementPicker; +#if 0 + using data_type = typename Arr::data_type; +#endif + + __host__ __device__ constexpr ContainerElementPicker() = delete; + + __host__ __device__ constexpr ContainerElementPicker(Arr& array) : mArray{array} + { + constexpr index_t imax = + reduce_on_sequence(Picks{}, math::maximize{}, Number<0>{}); + + static_assert(imax < Arr::Size(), "wrong! 
exceeding # array element"); + } + + __host__ __device__ static constexpr auto Size() { return Picks::Size(); } + + template + __host__ __device__ constexpr const auto& At(Number i) const + { + static_assert(I < Size(), "wrong!"); + + constexpr auto IP = Picks{}[i]; + return mArray[IP]; + } + + template + __host__ __device__ constexpr auto& At(Number i) + { + static_assert(I < Size(), "wrong!"); + + constexpr auto IP = Picks{}[i]; + return mArray(IP); + } + + template + __host__ __device__ constexpr const auto& operator[](Number i) const + { + return At(i); + } + + template + __host__ __device__ constexpr auto& operator()(Number i) + { + return At(i); + } + + template + __host__ __device__ constexpr auto operator=(const T& a) + { + static_assert(T::Size() == Size(), "wrong! size not the same"); + + static_for<0, Size(), 1>{}([&](auto i) { operator()(i) = a[i]; }); + + return *this; + } + + private: + Arr& mArray; +}; + +// Arr: Array or StaticallyIndexedArray +// Picks: Sequence<...> +template +struct ConstantContainerElementPicker +{ + using type = ConstantContainerElementPicker; +#if 0 + using data_type = typename Arr::data_type; +#endif + + __host__ __device__ constexpr ConstantContainerElementPicker() = delete; + + __host__ __device__ constexpr ConstantContainerElementPicker(const Arr& array) : mArray{array} + { + constexpr index_t imax = + reduce_on_sequence(Picks{}, math::maximize{}, Number<0>{}); + + static_assert(imax < Arr::Size(), "wrong! exceeding # array element"); + } + + __host__ __device__ static constexpr auto Size() { return Picks::Size(); } + + template + __host__ __device__ constexpr const auto& At(Number i) const + { + static_assert(I < Size(), "wrong!"); + + constexpr auto IP = Picks{}[i]; + return mArray[IP]; + } + + template + __host__ __device__ constexpr const auto& operator[](Number i) const + { + return At(i); + } + + private: + const Arr& mArray; +}; + +template +__host__ __device__ constexpr auto operator+=(ContainerElementPicker& y, const X& x) +{ + using Y = ContainerElementPicker; + constexpr index_t nsize = Y::Size(); + + static_assert(nsize == X::Size(), "wrong! size not the same"); + + static_for<0, nsize, 1>{}([&](auto i) { y(i) += x[i]; }); + + return y; +} + +template +__host__ __device__ constexpr auto operator-=(ContainerElementPicker& y, const X& x) +{ + using Y = ContainerElementPicker; + constexpr index_t nsize = Y::Size(); + + static_assert(nsize == X::Size(), "wrong! size not the same"); + + static_for<0, nsize, 1>{}([&](auto i) { y(i) -= x[i]; }); + + return y; +} + +template +__host__ __device__ constexpr auto pick_container_element(Arr& a, Picks) +{ + return ContainerElementPicker(a); +} + +template +__host__ __device__ constexpr auto pick_container_element(const Arr& a, Picks) +{ + return ConstantContainerElementPicker(a); +} + +} // namespace ck +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/utility/container_helper.hpp b/csrc/rocm/ck_extensions/include/ck/utility/container_helper.hpp new file mode 100755 index 000000000000..bd0ca42ecddb --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/container_helper.hpp @@ -0,0 +1,393 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#ifndef CK_CONTAINER_HELPER_HPP +#define CK_CONTAINER_HELPER_HPP + +#include "sequence.hpp" +#include "sequence_helper.hpp" +#include "array.hpp" +#include "tuple.hpp" +#include "tuple_helper.hpp" +#include "statically_indexed_array.hpp" +#include "container_element_picker.hpp" + +namespace ck { + +template +__host__ __device__ constexpr auto container_push_back(const Array& a, const TData& x) +{ + Array r; + + static_for<0, NSize, 1>{}([&r, &a ](auto i) constexpr { r(i) = a[i]; }); + + r(Number{}) = x; + + return r; +} + +template +__host__ __device__ constexpr auto container_push_front(const Tuple& a, const T& x) +{ + return container_concat(make_tuple(x), a); +} + +template +__host__ __device__ constexpr auto container_push_back(const Tuple& a, const T& x) +{ + return container_concat(a, make_tuple(x)); +} + +template +__host__ __device__ constexpr auto +container_reorder_given_new2old(const Array& old_array, Sequence /*new2old*/) +{ + static_assert(NSize == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + return make_array(old_array[Number{}]...); +} + +template +__host__ __device__ constexpr auto +container_reorder_given_old2new(const Array& old_array, Sequence old2new) +{ + return container_reorder_given_new2old( + old_array, typename sequence_map_inverse::type{}); +} + +template +__host__ __device__ constexpr auto container_reorder_given_new2old(const Tuple& old_tuple, + Sequence /*new2old*/) +{ + static_assert(sizeof...(Ts) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + return make_tuple(old_tuple[Number{}]...); +} + +template +__host__ __device__ constexpr auto container_reorder_given_old2new(const Tuple& old_tuple, + Sequence old2new) +{ + return container_reorder_given_new2old( + old_tuple, typename sequence_map_inverse::type{}); +} + +template +__host__ __device__ constexpr auto container_reorder_given_new2old(Sequence /* old_seq */, + Sequence /*new2old*/) +{ + static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! invalid reorder map"); + + return Sequence::At(Number{})...>{}; +} + +template +__host__ __device__ constexpr auto container_reorder_given_old2new(Sequence old_seq, + Sequence /* old2new */) +{ + static_assert(sizeof...(Is) == sizeof...(IRs), "wrong! size not consistent"); + + static_assert(is_valid_sequence_map>{}, "wrong! 
invalid reorder map"); + + constexpr auto new2old = typename sequence_map_inverse>::type{}; + + return container_reorder_given_new2old(old_seq, new2old); +} + +#if !CK_WORKAROUND_SWDEV_275126 +// rocm-4.1 compiler would crash for recursive lambda +template +__host__ __device__ constexpr auto container_reduce(const Container& x, + Reduce reduce, + Init init, + Number = Number<0>{}, + Number = Number{}, + Number = Number<1>{}) +{ + static_assert((IEnd - IBegin) % IStep == 0, "wrong!"); + + // f is recursive function, fs is a dummy of f + // i is index, y_old is current scan, r_old is current reduction + auto f = [&](auto fs, auto i, auto r_old) { + auto r_new = reduce(x[i], r_old); + + if constexpr(i.value < IEnd - IStep) + { + // recursively call f/fs + return fs(fs, i + Number{}, r_new); + } + else + { + return r_new; + } + }; + + // start recursion + return f(f, Number{}, init); +} +#else +// i is index, y_old is current scan, r_old is current reduction +template +__host__ __device__ constexpr auto container_reduce_impl( + const Container& x, Reduce reduce, ROld r_old, Number i, Number, Number) +{ + auto r_new = reduce(x[i], r_old); + + if constexpr(i.value < IEnd - IStep) + { + return container_reduce_impl( + x, reduce, r_new, i + Number{}, Number{}, Number{}); + } + else + { + return r_new; + } +} + +// rocm-4.1 compiler would crash for recursive lambda +// container reduce with initial value +template +__host__ __device__ constexpr auto container_reduce(const Container& x, + Reduce reduce, + Init init, + Number = Number<0>{}, + Number = Number{}, + Number = Number<1>{}) +{ + static_assert((IEnd - IBegin) % IStep == 0, "wrong!"); + + if constexpr(IEnd > IBegin) + { + return container_reduce_impl( + x, reduce, init, Number{}, Number{}, Number{}); + } + else + { + return init; + } +} +#endif + +template +__host__ __device__ constexpr auto +container_reverse_inclusive_scan(const Array& x, Reduce f, TData init) +{ + Array y; + + TData r = init; + + static_for{}([&](auto i) { + r = f(r, x[i]); + y(i) = r; + }); + + r = f(r, x[Number<0>{}]); + y(Number<0>{}) = r; + + return y; +} + +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Array& x, Reduce f, TData init) +{ + Array y; + + TData r = init; + + static_for{}([&](auto i) { + y(i) = r; + r = f(r, x[i]); + }); + + y(Number<0>{}) = r; + + return y; +} + +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Sequence& seq, Reduce f, Number) +{ + return reverse_exclusive_scan_sequence(seq, f, Number{}); +} + +#if !CK_WORKAROUND_SWDEV_275126 +// rocm4.1 compiler would crash with recursive lambda +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Tuple& x, Reduce reduce, Init init) +{ + constexpr index_t NSize = sizeof...(Xs); + + // f is recursive function, fs is a dummy of f + // i is index, y_old is current scan, r_old is current reduction + auto f = [&](auto fs, auto i, auto y_old, auto r_old) { + auto r_new = reduce(x[i], r_old); + + auto y_new = container_push_front(y_old, r_new); + + if constexpr(i.value > 1) + { + // recursively call f/fs + return fs(fs, i - Number<1>{}, y_new, r_new); + } + else + { + return y_new; + } + }; + + // start recursion + return f(f, Number{}, make_tuple(init), init); +} +#else +// i is index, y_old is current scan, r_old is current reduction +template +__host__ __device__ constexpr auto container_reverse_exclusive_scan_impl( + const Tuple& x, Reduce reduce, Number i, YOld y_old, ROld r_old) +{ + auto r_new = 
reduce(x[i], r_old); + + auto y_new = container_push_front(y_old, r_new); + + if constexpr(i.value > 1) + { + // recursively call f/fs + return container_reverse_exclusive_scan_impl(x, reduce, i - Number<1>{}, y_new, r_new); + } + else + { + return y_new; + } +} + +template +__host__ __device__ constexpr auto +container_reverse_exclusive_scan(const Tuple& x, Reduce reduce, Init init) +{ + constexpr index_t NSize = sizeof...(Xs); + + return container_reverse_exclusive_scan_impl( + x, reduce, Number{}, make_tuple(init), init); +} +#endif + +// TODO: update to like container_reverse_exclusive_scan to deal with Tuple of Numebr<> +template +__host__ __device__ constexpr auto +container_reverse_inclusive_scan(const Tuple& x, Reduce f, TData init) +{ + constexpr index_t NSize = sizeof...(Xs); + + Tuple y; + + TData r = init; + + static_for{}([&](auto i) { + r = f(r, x[i]); + y(i) = r; + }); + + r = f(r, x[Number<0>{}]); + y(Number<0>{}) = r; + + return y; +} + +template +__host__ __device__ constexpr auto container_concat(const X& x, const Ys&... ys) +{ + return container_concat(x, container_concat(ys...)); +} + +template +__host__ __device__ constexpr auto container_concat(const Array& ax, const Array& ay) +{ + return unpack2( + [&](auto&&... zs) { return make_array(ck::forward(zs)...); }, ax, ay); +} + +template +__host__ __device__ constexpr auto container_concat(const Tuple& tx, const Tuple& ty) +{ + return unpack2( + [&](auto&&... zs) { return make_tuple(ck::forward(zs)...); }, tx, ty); +} + +template +__host__ __device__ constexpr auto container_concat(const Container& x) +{ + return x; +} + +template +__host__ __device__ constexpr auto get_container_subset(const Array& arr, Sequence) +{ + static_assert(N >= sizeof...(Is), "wrong! size"); + + return make_array(arr[Number{}]...); +} + +template +__host__ __device__ constexpr auto get_container_subset(const Tuple& tup, Sequence) +{ + static_assert(sizeof...(Ts) >= sizeof...(Is), "wrong! size"); + + return make_tuple(tup[Number{}]...); +} + +template +__host__ __device__ constexpr void +set_container_subset(Array& y, Sequence picks, const Array& x) +{ + static_assert(N >= sizeof...(Is), "wrong! size"); + + static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); +} + +template +__host__ __device__ constexpr void +set_container_subset(Tuple& y, Sequence picks, const Tuple& x) +{ + static_assert(sizeof...(Ys) >= sizeof...(Is) && sizeof...(Is) == sizeof...(Xs), "wrong! size"); + + static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; }); +} + +template +__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence) +{ + using Seq = Sequence; + + return generate_tuple( + [&](auto i) { + constexpr index_t tmp = Seq::At(i); + return Number{}; + }, + Seq::Size()); +} + +} // namespace ck +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/utility/data_type.hpp b/csrc/rocm/ck_extensions/include/ck/utility/data_type.hpp new file mode 100755 index 000000000000..5fbe30d21ba4 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/data_type.hpp @@ -0,0 +1,464 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/amd_ck_fp8.hpp" +#include "ck/utility/e8m0.hpp" +#include "ck/utility/statically_indexed_array.hpp" + +/// Definitions from , conflict with +/// /opt/rocm/include/hip/amd_detail/amd_hip_vector_types.h. 
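+/// When compiling through HIPRTC (or the CK RTC code-gen path) the standard
+/// headers are unavailable, so the fixed-width integer aliases used in this
+/// file are declared manually in the block below.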
+ +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) +#define CHAR_BIT 8 +using int8_t = signed char; +using uint8_t = unsigned char; +using int16_t = signed short; +using uint16_t = unsigned short; +using float_t = float; +#endif // __HIPCC_RTC__ + +namespace ck { +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) +using byte = unsigned char; +#else +using std::byte; +#endif + +using bhalf_t = ushort; +using half_t = _Float16; +using int4_t = _BitInt(4); +using f4_t = unsigned _BitInt(4); +using f6_t = _BitInt(6); // e2m3 format +using bf6_t = unsigned _BitInt(6); // e3m2 format + +// scalar_type +template +struct scalar_type; + +struct f4x2_pk_t +{ + static constexpr int packed_size = 2; + + using type = uint8_t; + type data; + __host__ __device__ constexpr f4x2_pk_t() : data{type{}} {} + __host__ __device__ constexpr f4x2_pk_t(const type init) : data{init} {} + + template + __host__ __device__ inline type unpack(Number) const + { + static_assert(I < 2, "Index is out of range."); + if constexpr(I == 1) + return (data >> 4); + else + return data & 0b00001111; + } + + __host__ __device__ inline type pack(const type x0, const type x1) + { + return (x1 << 4) | (x0 & 0b00001111); + } + + // Compare operator + __host__ __device__ friend bool operator==(const f4x2_pk_t& lhs, const f4x2_pk_t& rhs) + { + return lhs.data == rhs.data; + } + + __host__ __device__ friend bool operator!=(const f4x2_pk_t& lhs, const f4x2_pk_t& rhs) + { + return !(lhs == rhs); + } +}; + +template +struct f6_pk_t +{ + using element_type = uint32_t; // element storage fundamental type + + static constexpr index_t packed_size = pk_size; // 16 or 32 for now + static constexpr index_t num_bits_elem = 6; // specialized for 6-bit data + // XXX: CHAR_BIT is not defined in HIPRTC, so we must use 8 + static constexpr index_t num_bits_vec_elem = + sizeof(element_type) * 8; // 32-bit uint for storage + static_assert((packed_size * num_bits_elem) % num_bits_vec_elem == 0, + "Packed elements must fit exactly into the element storage."); + static constexpr index_t vector_size = + (packed_size * num_bits_elem) / num_bits_vec_elem; // 3 or 6 element_type units + + using storage_type = element_type __attribute__((ext_vector_type(vector_size))); + storage_type data_{storage_type(0)}; // packed data + + using type = f6_pk_t; + + __host__ __device__ constexpr f6_pk_t() {} + __host__ __device__ constexpr f6_pk_t(const storage_type& init) : data_{init} + { + // TODO: consider removing initialization similar to vector_type + } + + // Initialize from a vector type with the same size as packed_size + template ::vector_size == packed_size>> + __host__ __device__ f6_pk_t(const T& v) + { + static_for<0, packed_size, 1>{}( + [&](auto i) { pack(v[static_cast(i)], static_cast(i)); }); + } + + // Broadcast single initialization value to all packed elements + __host__ __device__ f6_pk_t(const int8_t v) + : f6_pk_t(static_cast(v)) + { + // TODO: consider removing initialization similar to vector_type + } + + template + __host__ __device__ void pack(const T x, const index_t i) + { + static_assert(is_integral::value || is_same_v, + "T must be an integral type."); + + uint32_t bits = static_cast(x) & 0x3F; + const int bit_pos = i * num_bits_elem; + const int arr_index = bit_pos / num_bits_vec_elem; + const int bit_offset = bit_pos % num_bits_vec_elem; + const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + uint32_t old_value = data_[arr_index]; + + // insert bits into the current 32-bit block + old_value |= (bits << bit_offset); + 
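+        // Worked example (illustrative): for i = 5, bit_pos = 30, so
+        // arr_index = 0, bit_offset = 30 and overhang = 30 + 6 - 32 = 4.
+        // The shift above places the low 2 bits of the element in bits
+        // [31:30] of data_[0]; the remaining 4 high bits are written into
+        // bits [3:0] of data_[1] by the overhang branch below.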
data_[arr_index] = old_value; + + // if it crosses into the next block, shift the remainder + if(overhang > 0 && (arr_index + 1) < vector_size) + { + uint32_t next_value = data_[arr_index + 1]; + next_value |= (bits >> (num_bits_elem - overhang)); + data_[arr_index + 1] = next_value; + } + } + + __host__ __device__ static inline BitType unpack(const type& pk, const index_t i) + { + const int bit_pos = i * num_bits_elem; + const int arr_idx = bit_pos / num_bits_vec_elem; + const int bit_offset = bit_pos % num_bits_vec_elem; + const int overhang = bit_offset + num_bits_elem - num_bits_vec_elem; + + uint32_t bits = pk.data_[arr_idx] >> bit_offset; + if(overhang > 0 && (arr_idx + 1) < vector_size) + { + bits |= (pk.data_[arr_idx + 1] & ((1u << overhang) - 1)) << (num_bits_elem - overhang); + } + + return static_cast(bits & 0x3F); + } + + __host__ __device__ inline BitType unpack(const index_t i) const { return unpack(*this, i); } + + // Compare operator + __host__ __device__ friend bool operator==(const f6_pk_t& lhs, const f6_pk_t& rhs) + { +#pragma unroll + for(index_t i = 0; i < vector_size; ++i) + { + if(lhs.data_[i] != rhs.data_[i]) + return false; + } + return true; + } + + __host__ __device__ friend bool operator!=(const f6_pk_t& lhs, const f6_pk_t& rhs) + { + return !(lhs == rhs); + } +}; + +using f6x16_pk_t = f6_pk_t; +using f6x32_pk_t = f6_pk_t; +using bf6x16_pk_t = f6_pk_t; +using bf6x32_pk_t = f6_pk_t; + +// custom data type - pack int4 data +struct pk_i4_t +{ + using type = int8_t; + type data; + __host__ __device__ constexpr pk_i4_t() : data{type{}} {} + __host__ __device__ constexpr pk_i4_t(type init) : data{init} {} +}; + +inline constexpr auto next_pow2(uint32_t x) +{ + // Precondition: x > 1. + return x > 1u ? (1u << (32u - __builtin_clz(x - 1u))) : x; +} + +// native types: double, float, _Float16, ushort, int32_t, int8_t, uint8_t, f8_fnuz_t, bf8_fnuz_t, +// native types: bool +template +inline constexpr bool is_native_type() +{ + return is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value; +} + +// scalar_type +template +struct scalar_type; + +// is_scalar_type +template +struct is_scalar_type +{ + static constexpr bool value = (scalar_type>::vector_size == 1); +}; + +// has_same_scalar_type +template +using has_same_scalar_type = is_same>::type, + typename scalar_type>::type>; + +template +struct scalar_type +{ + using type = T; + static constexpr index_t vector_size = N; +}; + +// +template <> +struct scalar_type +{ + using type = double; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = float; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = half_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = bhalf_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = int32_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = int8_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = uint8_t; + static constexpr index_t vector_size = 1; +}; + +#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 +template <> +struct scalar_type +{ + using type = int4_t; + static constexpr index_t vector_size = 1; +}; +#endif + +template <> +struct scalar_type +{ 
+ using type = pk_i4_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = f8_fnuz_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = bf8_fnuz_t; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = f8_ocp_t::data_type; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = bf8_ocp_t::data_type; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = e8m0_bexp_t::type; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = f4x2_pk_t::type; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = f6x32_pk_t::storage_type; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = bf6x32_pk_t::storage_type; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = f6x16_pk_t::storage_type; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = bf6x16_pk_t::storage_type; + static constexpr index_t vector_size = 1; +}; + +template <> +struct scalar_type +{ + using type = bool; + static constexpr index_t vector_size = 1; +}; + +template +struct packed_type_info +{ + private: + static constexpr auto get_packed_type_info() + { + using U = remove_cvref_t; + if constexpr(is_same_v) + return ck::Tuple, int4_t>{}; + else if constexpr(is_same_v) + return ck::Tuple, f4_t>{}; + else if constexpr(is_same_v) + return ck::Tuple, f6_t>{}; + else if constexpr(is_same_v) + return ck::Tuple, bf6_t>{}; + else if constexpr(is_same_v) + return ck::Tuple, f6_t>{}; + else if constexpr(is_same_v) + return ck::Tuple, bf6_t>{}; + else + return ck::Tuple, T>{}; + } + + public: + using element_type = remove_cvref_t{}))>; + static constexpr auto packed_size = + static_cast(get_packed_type_info().At(ck::Number<0>{})); +}; +template +using element_type_t = typename packed_type_info::element_type; + +template +inline constexpr index_t packed_size_v = packed_type_info::packed_size; + +template +inline constexpr bool is_packed_type_v = packed_size_v > 1; + +template +struct packed_type_maker +{ + private: + static constexpr auto get_packed_type() + { + using U = remove_cvref_t; + if constexpr(is_same_v) + { + static_assert(N == 0 || N == 2, "Packed size N for int4_t must be 2."); + return pk_i4_t{}; + } + else if constexpr(is_same_v) + { + static_assert(N == 0 || N == 2, "Packed size N for f4_t must be 2."); + return f4x2_pk_t{}; + } + else if constexpr(is_same_v) + { + static_assert(N == 0 || N == 16 || N == 32, "Packed size N for f6_t must be 16 or 32."); + if constexpr(N == 16) + return f6x16_pk_t{}; + else if constexpr(N == 0 || N == 32) + return f6x32_pk_t{}; + } + else if constexpr(is_same_v) + { + static_assert(N == 0 || N == 16 || N == 32, + "Packed size N for bf6_t must be 16 or 32."); + if constexpr(N == 16) + return bf6x16_pk_t{}; + else if constexpr(N == 0 || N == 32) + return bf6x32_pk_t{}; + } + else + return T{}; + } + + public: + using packed_type = remove_cvref_t; +}; + +template +using packed_type_t = typename packed_type_maker::packed_type; + +#if defined(_WIN32) +using int64_t = long long; +#else +using int64_t = long; +#endif + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/debug.hpp 
b/csrc/rocm/ck_extensions/include/ck/utility/debug.hpp new file mode 100755 index 000000000000..45d443ae495b --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/debug.hpp @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef UTILITY_DEBUG_HPP +#define UTILITY_DEBUG_HPP +#include "type.hpp" + +namespace ck { +namespace debug { + +namespace detail { +template +struct PrintAsType; + +template +struct PrintAsType::value>::type> +{ + using type = float; + __host__ __device__ static void Print(const T& p) { printf("%.3f ", static_cast(p)); } +}; + +template <> +struct PrintAsType +{ + using type = float; + __host__ __device__ static void Print(const ck::half_t& p) + { + printf("%.3f ", static_cast(p)); + } +}; + +template +struct PrintAsType::value>::type> +{ + using type = int; + __host__ __device__ static void Print(const T& p) { printf("%d ", static_cast(p)); } +}; +} // namespace detail + +// Print at runtime the data in shared memory in 128 bytes per row format given shared mem pointer +// and the number of elements. Can optionally specify strides between elements and how many bytes' +// worth of data per row. +// +// Usage example: +// +// debug::print_shared(a_block_buf.p_data_, index_t(a_block_desc_k0_m_k1.GetElementSpaceSize())); +// +template +__device__ void print_shared(T const* p_shared, index_t num_elements) +{ + constexpr index_t row_elements = row_bytes / sizeof(T); + static_assert((element_stride >= 1 && element_stride <= row_elements), + "element_stride should between [1, row_elements]"); + + index_t wgid = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; + index_t tid = + (threadIdx.z * (blockDim.x * blockDim.y)) + (threadIdx.y * blockDim.x) + threadIdx.x; + + __syncthreads(); + + if(tid == 0) + { + printf("\nWorkgroup id %d, bytes per row %d, element stride %d\n\n", + wgid, + row_bytes, + element_stride); + for(index_t i = 0; i < num_elements; i += row_elements) + { + printf("elem %5d: ", i); + for(index_t j = 0; j < row_elements; j += element_stride) + { + detail::PrintAsType::Print(p_shared[i + j]); + } + + printf("\n"); + } + printf("\n"); + } + + __syncthreads(); +} + +template +__device__ static bool is_thread_local_1d_id_idx() +{ + const auto tid = get_thread_local_1d_id(); + return ((tid == Ids) || ...); +} + +// Use `CK_PRINT()` to inspect values of type T1, T2, ... +// Use `CK_PRINT()` to inspect constexpr values of val1, val2, ... of the same type +// In a non-evaluated context, you can use `using _dummy = decltype(CK_PRINT<...>());` +// Set BUILD_DEV to OFF to avoid enabling Werror +template +[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT() +{ +} +template +[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT() +{ +} + +} // namespace debug +} // namespace ck + +#endif // UTILITY_DEBUG_HPP diff --git a/csrc/rocm/ck_extensions/include/ck/utility/dtype_fp64.hpp b/csrc/rocm/ck_extensions/include/ck/utility/dtype_fp64.hpp new file mode 100755 index 000000000000..3c63d083ade4 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/dtype_fp64.hpp @@ -0,0 +1,7 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+namespace ck { +// fp64 +using double2_t = typename vector_type::type; +using double4_t = typename vector_type::type; +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/dtype_vector.hpp b/csrc/rocm/ck_extensions/include/ck/utility/dtype_vector.hpp new file mode 100755 index 000000000000..ae0edb35ee31 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/dtype_vector.hpp @@ -0,0 +1,2269 @@ +// SPDX-License-Identifier: MIT +// // // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once +#include "ck/utility/data_type.hpp" + +namespace ck { + +// vector_type +template +struct vector_type; + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of +// vectors" +template +struct vector_type; + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of +// vectors" +template +struct vector_type, N>; + +// vector_type_maker +// This is the right way to handle "vector of vectors": making a bigger vector instead +template +struct vector_type_maker +{ + using type = vector_type; +}; + +template +struct scalar_type> +{ + using type = T; + static constexpr index_t vector_size = N; +}; + +template +struct vector_type_maker +{ + using type = vector_type; +}; + +template +struct vector_type_maker, N0> +{ + using type = vector_type; +}; + +template +using vector_type_maker_t = typename vector_type_maker::type; + +template +__host__ __device__ constexpr auto make_vector_type(Number) +{ + return typename vector_type_maker::type{}; +} + +template +struct vector_type()>> +{ + using d1_t = T; + using type = d1_t; + + union + { + T d1_; + StaticallyIndexedArray d1x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value, + "Something went wrong, please check src and dst types."); + + return data_.d1x1_; + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value, + "Something went wrong, please check src and dst types."); + + return data_.d1x1_; + } +}; + +__device__ int static err = 0; +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + + using type = d2_t; + + union + { + d2_t d2_; + StaticallyIndexedArray d1x2_; + StaticallyIndexedArray d2x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return 
data_.d2x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d3_t __attribute__((ext_vector_type(3))); + + using type = d3_t; + + union + { + d3_t d3_; + StaticallyIndexedArray d1x3_; + StaticallyIndexedArray d2x1_; + StaticallyIndexedArray d3x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x3_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else if constexpr(is_same::value) + { + return data_.d3x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x3_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else if constexpr(is_same::value) + { + return data_.d3x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + + using type = d4_t; + + union + { + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d5_t __attribute__((ext_vector_type(5))); + + using type = d5_t; + + union + { + d5_t d5_; + StaticallyIndexedArray d1x5_; + StaticallyIndexedArray d4x1_; + StaticallyIndexedArray d5x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x5_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else if constexpr(is_same::value) + { + return data_.d5x1_; + } + else + { + 
return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x5_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else if constexpr(is_same::value) + { + return data_.d5x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d3_t __attribute__((ext_vector_type(3))); + typedef T d6_t __attribute__((ext_vector_type(6))); + + using type = d6_t; + + union + { + d6_t d6_; + StaticallyIndexedArray d1x6_; + StaticallyIndexedArray d2x3_; + StaticallyIndexedArray d3x2_; + StaticallyIndexedArray d6x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x6_; + } + else if constexpr(is_same::value) + { + return data_.d2x3_; + } + else if constexpr(is_same::value) + { + return data_.d3x2_; + } + else if constexpr(is_same::value) + { + return data_.d6x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x6_; + } + else if constexpr(is_same::value) + { + return data_.d2x3_; + } + else if constexpr(is_same::value) + { + return data_.d3x2_; + } + else if constexpr(is_same::value) + { + return data_.d6x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d7_t __attribute__((ext_vector_type(7))); + + using type = d7_t; + + union + { + d7_t d7_; + StaticallyIndexedArray d1x7_; + StaticallyIndexedArray d2x3_; + StaticallyIndexedArray d4x1_; + StaticallyIndexedArray d7x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x7_; + } + else if constexpr(is_same::value) + { + return data_.d2x3_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else if constexpr(is_same::value) + { + return data_.d7x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x7_; + } + else if constexpr(is_same::value) + { + return data_.d2x3_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else if constexpr(is_same::value) + { + return data_.d7x1_; + } + 
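Lane-wise traversal of any of these specializations goes through the d1xN view together with static_for, which is also how DynamicBuffer::Update iterates later in this diff. A small sketch for the specialized widths and an arithmetic element type such as float (assumes ck::static_for; the helper name is illustrative):

    template <typename T, ck::index_t N>
    __host__ __device__ T sum_lanes(const ck::vector_type<T, N>& v)
    {
        T acc{0};
        // d1xN view: one scalar per lane, indexed with Number<I> by static_for
        ck::static_for<0, N, 1>{}([&](auto i) { acc += v.template AsType<T>()[i]; });
        return acc;
    }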
else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + + using type = d8_t; + + union + { + d8_t d8_; + StaticallyIndexedArray d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d13_t __attribute__((ext_vector_type(13))); + + using type = d13_t; + + union + { + d13_t d13_; + StaticallyIndexedArray d1x13_; + StaticallyIndexedArray d4x3_; + StaticallyIndexedArray d8x1_; + StaticallyIndexedArray d13x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x13_; + } + else if constexpr(is_same::value) + { + return data_.d4x3_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else if constexpr(is_same::value) + { + return data_.d13x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x13_; + } + else if constexpr(is_same::value) + { + return data_.d4x3_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else if constexpr(is_same::value) + { + return data_.d13x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + + using type = d16_t; + + union + { + d16_t d16_; + StaticallyIndexedArray d1x16_; + 
StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + + using type = d32_t; + + union + { + d32_t d32_; + StaticallyIndexedArray d1x32_; + StaticallyIndexedArray d2x16_; + StaticallyIndexedArray d4x8_; + StaticallyIndexedArray d8x4_; + StaticallyIndexedArray d16x2_; + StaticallyIndexedArray d32x1_; + } data_ = {d32_t{0}}; + + __attribute__((host)) __attribute__((device)) constexpr vector_type() {} + + __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; } + + // __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + // __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if 
constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + + using type = d64_t; + + union + { + d64_t d64_; + StaticallyIndexedArray d1x64_; + StaticallyIndexedArray d2x32_; + StaticallyIndexedArray d4x16_; + StaticallyIndexedArray d8x8_; + StaticallyIndexedArray d16x4_; + StaticallyIndexedArray d32x2_; + StaticallyIndexedArray d64x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + typedef T d128_t __attribute__((ext_vector_type(128))); + + using type = d128_t; + + union + { + d128_t d128_; + StaticallyIndexedArray d1x128_; + StaticallyIndexedArray d2x64_; + StaticallyIndexedArray d4x32_; + StaticallyIndexedArray d8x16_; + StaticallyIndexedArray d16x8_; + StaticallyIndexedArray d32x4_; + StaticallyIndexedArray d64x2_; + StaticallyIndexedArray d128x1_; + } data_ = {d128_t{0}}; + + __attribute__((host)) __attribute__((device)) constexpr vector_type() {} + + __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; 
} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x128_; + } + else if constexpr(is_same::value) + { + return data_.d2x64_; + } + else if constexpr(is_same::value) + { + return data_.d4x32_; + } + else if constexpr(is_same::value) + { + return data_.d8x16_; + } + else if constexpr(is_same::value) + { + return data_.d16x8_; + } + else if constexpr(is_same::value) + { + return data_.d32x4_; + } + else if constexpr(is_same::value) + { + return data_.d64x2_; + } + else if constexpr(is_same::value) + { + return data_.d128x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x128_; + } + else if constexpr(is_same::value) + { + return data_.d2x64_; + } + else if constexpr(is_same::value) + { + return data_.d4x32_; + } + else if constexpr(is_same::value) + { + return data_.d8x16_; + } + else if constexpr(is_same::value) + { + return data_.d16x8_; + } + else if constexpr(is_same::value) + { + return data_.d32x4_; + } + else if constexpr(is_same::value) + { + return data_.d64x2_; + } + else if constexpr(is_same::value) + { + return data_.d128x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + typedef T d2_t __attribute__((ext_vector_type(2))); + typedef T d4_t __attribute__((ext_vector_type(4))); + typedef T d8_t __attribute__((ext_vector_type(8))); + typedef T d16_t __attribute__((ext_vector_type(16))); + typedef T d32_t __attribute__((ext_vector_type(32))); + typedef T d64_t __attribute__((ext_vector_type(64))); + typedef T d128_t __attribute__((ext_vector_type(128))); + typedef T d256_t __attribute__((ext_vector_type(256))); + + using type = d256_t; + + union + { + d256_t d256_; + StaticallyIndexedArray d1x256_; + StaticallyIndexedArray d2x128_; + StaticallyIndexedArray d4x64_; + StaticallyIndexedArray d8x32_; + StaticallyIndexedArray d16x16_; + StaticallyIndexedArray d32x8_; + StaticallyIndexedArray d64x4_; + StaticallyIndexedArray d128x2_; + StaticallyIndexedArray d256x1_; + } data_ = {d256_t{0}}; + + __attribute__((host)) __attribute__((device)) constexpr vector_type() {} + + __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { (void)v; } + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert( + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x256_; + } + else if constexpr(is_same::value) + { + return data_.d2x128_; + } + else if constexpr(is_same::value) + { + return data_.d4x64_; + } + else if constexpr(is_same::value) + { + return data_.d8x32_; + } + else if constexpr(is_same::value) + { + return data_.d16x16_; + } + else if constexpr(is_same::value) + { + return data_.d32x8_; + } + else if constexpr(is_same::value) + { + return 
data_.d64x4_; + } + else if constexpr(is_same::value) + { + return data_.d128x2_; + } + else if constexpr(is_same::value) + { + return data_.d256x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert( + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value || + is_same::value || is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x256_; + } + else if constexpr(is_same::value) + { + return data_.d2x128_; + } + else if constexpr(is_same::value) + { + return data_.d4x64_; + } + else if constexpr(is_same::value) + { + return data_.d8x32_; + } + else if constexpr(is_same::value) + { + return data_.d16x16_; + } + else if constexpr(is_same::value) + { + return data_.d32x8_; + } + else if constexpr(is_same::value) + { + return data_.d64x4_; + } + else if constexpr(is_same::value) + { + return data_.d128x2_; + } + else if constexpr(is_same::value) + { + return data_.d256x1_; + } + else + { + return err; + } + } +}; + +template +struct non_native_vector_base; + +template +struct nnvb_data_t_selector +{ + using type = unsigned _BitInt(8 * sizeof(T)); +}; + +template <> +struct nnvb_data_t_selector +{ + using type = f8_ocp_t::data_type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf8_ocp_t::data_type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = e8m0_bexp_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = f6x16_pk_t::storage_type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = f6x32_pk_t::storage_type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf6x16_pk_t::storage_type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf6x32_pk_t::storage_type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = pk_i4_t::type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = f4x2_pk_t::type; +}; + +template +struct non_native_vector_base< + T, + N, + ck::enable_if_t> +{ + using data_t = typename nnvb_data_t_selector::type; // select data_t based on the size of T + static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); + using data_v = data_t __attribute__((ext_vector_type(N))); + using type = non_native_vector_base; + + union alignas(next_pow2(N * sizeof(T))) + { + data_v dN; // storage vector; + StaticallyIndexedArray dxN; + StaticallyIndexedArray dTxN; + StaticallyIndexedArray dNx1; + } data_; + + __host__ __device__ constexpr non_native_vector_base(data_t a) : data_{data_v(a)} {} + __host__ __device__ constexpr non_native_vector_base(T f) + : non_native_vector_base(bit_cast(f)) + { + } + __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; + __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} + + __host__ __device__ constexpr operator data_v() const { return data_.dN; } + __host__ __device__ constexpr operator data_t() const + { + if constexpr(N == 1) + { + return data_.dxN[Number<0>{}]; + } + else + { + return data_.dxN; // XXX this should cause an error + } + } + __host__ __device__ constexpr operator T() const + { + if constexpr(N == 1) + { + return data_.dTxN[Number<0>{}]; + } + else + { + return data_.dTxN; // XXX this should cause an error + } + } + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same_v 
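For element types without a native vector form, nnvb_data_t_selector picks the storage type: by default an unsigned _BitInt of matching width, otherwise the explicit storage named by one of the specializations above. A compile-time sketch, assuming the primary non_native_vector_base template defaults its enable_if parameter as the two-argument uses elsewhere in this header imply (f8_ocp_t comes from CK's data_type headers; the alias name is illustrative):

    static_assert(sizeof(typename ck::nnvb_data_t_selector<ck::f8_ocp_t>::type) == sizeof(ck::f8_ocp_t),
                  "selected storage must be layout-compatible with the element type");
    // the vector base widens that storage into an N-wide ext_vector
    using f8x4_storage_vec = typename ck::non_native_vector_base<ck::f8_ocp_t, 4>::data_v;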
|| is_same_v || is_same_v, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same_v) + { + return data_.dxN; + } + else if constexpr(is_same_v) + { + return data_.dTxN; + } + else if constexpr(is_same_v) + { + return data_.dNx1; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same_v || is_same_v || is_same_v, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same_v) + { + return data_.dxN; + } + else if constexpr(is_same_v) + { + return data_.dTxN; + } + else if constexpr(is_same_v) + { + return data_.dNx1; + } + else + { + return err; + } + } +}; + +// implementation for f6x16 and f6x32 +template +struct non_native_vector_base< + T, + N, + ck::enable_if_t> +{ + using data_t = + typename nnvb_data_t_selector::type; // select data_t based on declared base type + using element_t = typename T::element_type; // select element_t based on declared element type + static_assert(sizeof(T) == sizeof(data_t), "non_native_vector_base storage size mismatch"); + static constexpr size_t size_factor = sizeof(data_t) / sizeof(element_t); + using data_v = element_t __attribute__((ext_vector_type(N * size_factor))); + using type = non_native_vector_base; + + union alignas(next_pow2(N * sizeof(T))) + { + data_v dN; // storage vector; + StaticallyIndexedArray dxN; + StaticallyIndexedArray dTxN; + StaticallyIndexedArray dNx1; + } data_; + + // Broadcast single value to vector + __host__ __device__ constexpr non_native_vector_base(data_t a) : data_{} + { + // TODO: consider removing initialization similar to vector_type + + ck::static_for<0, N, 1>{}([&](auto i) { + data_.dxN(i) = a; // broadcast value to all elements + }); + } + + __host__ __device__ constexpr non_native_vector_base(T f) + : non_native_vector_base(bit_cast(f)) + { + } + + __host__ __device__ constexpr non_native_vector_base() : non_native_vector_base(T{}){}; + + __host__ __device__ constexpr non_native_vector_base(data_v v) : data_{v} {} + + __host__ __device__ constexpr non_native_vector_base(element_t v) : data_{data_v(v)} {} + + __host__ __device__ constexpr operator data_v() const { return data_.dN; } + + __host__ __device__ constexpr operator T() const + { + if constexpr(N == 1) + { + return data_.dTxN[Number<0>{}]; + } + else + { + return err; // XXX this should cause an error + } + } + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same_v || is_same_v || is_same_v, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same_v) + { + return data_.dNx1; + } + else if constexpr(is_same_v) + { + return data_.dxN; + } + else if constexpr(is_same_v) + { + return data_.dTxN; + } + else + { + return err; + } + } +}; + +template +struct scalar_type>> +{ + using type = typename non_native_vector_base::data_t; + static constexpr index_t vector_size = N; +}; + +template +struct scalar_type>> +{ + using type = typename non_native_vector_base::element_t; + static constexpr index_t vector_size = N * non_native_vector_base::size_factor; +}; + +// non-native vector_type implementation +template +struct vector_type()>> +{ + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using type = d1_nnv_t; + + union alignas(next_pow2(1 * sizeof(T))) + { + d1_t d1_; + StaticallyIndexedArray d1x1_; + d1_nnv_t d1_nnv_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{d1_t{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} 
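The f6x16/f6x32 specialization above widens its storage by size_factor because one packed object carries several 6-bit elements. A compile-time sketch of that bookkeeping, assuming f6x16_pk_t exposes the storage_type/element_type members the specialization relies on:

    using f6_vec = ck::non_native_vector_base<ck::f6x16_pk_t, 2>;
    static_assert(f6_vec::size_factor ==
                      sizeof(typename ck::f6x16_pk_t::storage_type) /
                          sizeof(typename ck::f6x16_pk_t::element_type),
                  "lanes contributed per packed f6 object");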
{} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; + + using type = d2_t; + + union alignas(next_pow2(2 * sizeof(T))) + { + d2_t d2_; + StaticallyIndexedArray d1x2_; + StaticallyIndexedArray d2x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x2_; + } + else if constexpr(is_same::value) + { + return data_.d2x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + + using type = d4_t; + + union alignas(next_pow2(4 * sizeof(T))) + { + d4_t d4_; + StaticallyIndexedArray d1x4_; + StaticallyIndexedArray d2x2_; + StaticallyIndexedArray d4x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x4_; + } + else if constexpr(is_same::value) + { + return data_.d2x2_; + } + else if constexpr(is_same::value) + { + return data_.d4x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + + using type = d8_t; + + union alignas(next_pow2(8 * sizeof(T))) + { + d8_t d8_; + StaticallyIndexedArray 
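These non-native vector_type specializations pad their unions with alignas(next_pow2(N * sizeof(T))) so the widest member stays naturally aligned for vector memory operations. A compile-time sketch (pk_i4_t is CK's packed-int4 type; ck::next_pow2 is the helper already used in the alignas above):

    static_assert(alignof(ck::vector_type<ck::pk_i4_t, 4>) == ck::next_pow2(4 * sizeof(ck::pk_i4_t)),
                  "storage is padded to a power-of-two alignment");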
d1x8_; + StaticallyIndexedArray d2x4_; + StaticallyIndexedArray d4x2_; + StaticallyIndexedArray d8x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x8_; + } + else if constexpr(is_same::value) + { + return data_.d2x4_; + } + else if constexpr(is_same::value) + { + return data_.d4x2_; + } + else if constexpr(is_same::value) + { + return data_.d8x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d1_nnv_t = non_native_vector_base; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; + + using type = d16_t; + + union alignas(next_pow2(16 * sizeof(T))) + { + d16_t d16_; + StaticallyIndexedArray d1x16_; + StaticallyIndexedArray d2x8_; + StaticallyIndexedArray d4x4_; + StaticallyIndexedArray d8x2_; + StaticallyIndexedArray d16x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value || is_same::value) + { + return data_.d1x16_; + } + else if constexpr(is_same::value) + { + return data_.d2x8_; + } + else if constexpr(is_same::value) + { + return data_.d4x4_; + } + else if constexpr(is_same::value) + { + return data_.d8x2_; + } + else if constexpr(is_same::value) + { + return data_.d16x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; + using d32_t = non_native_vector_base; + + using type = 
d32_t; + + union alignas(next_pow2(32 * sizeof(T))) + { + d32_t d32_; + StaticallyIndexedArray d1x32_; + StaticallyIndexedArray d2x16_; + StaticallyIndexedArray d4x8_; + StaticallyIndexedArray d8x4_; + StaticallyIndexedArray d16x2_; + StaticallyIndexedArray d32x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x32_; + } + else if constexpr(is_same::value) + { + return data_.d2x16_; + } + else if constexpr(is_same::value) + { + return data_.d4x8_; + } + else if constexpr(is_same::value) + { + return data_.d8x4_; + } + else if constexpr(is_same::value) + { + return data_.d16x2_; + } + else if constexpr(is_same::value) + { + return data_.d32x1_; + } + else + { + return err; + } + } +}; + +template +struct vector_type()>> +{ + using d1_t = T; + using d2_t = non_native_vector_base; + using d4_t = non_native_vector_base; + using d8_t = non_native_vector_base; + using d16_t = non_native_vector_base; + using d32_t = non_native_vector_base; + using d64_t = non_native_vector_base; + + using type = d64_t; + + union alignas(next_pow2(64 * sizeof(T))) + { + d64_t d64_; + StaticallyIndexedArray d1x64_; + StaticallyIndexedArray d2x32_; + StaticallyIndexedArray d4x16_; + StaticallyIndexedArray d8x8_; + StaticallyIndexedArray d16x4_; + StaticallyIndexedArray d32x2_; + StaticallyIndexedArray d64x1_; + } data_; + + __host__ __device__ constexpr vector_type() : data_{type{}} {} + + __host__ __device__ constexpr vector_type(type v) : data_{v} {} + + template + __host__ __device__ constexpr const auto& AsType() const + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } + + template + __host__ __device__ constexpr auto& AsType() + { + static_assert(is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value, + "Something went 
wrong, please check src and dst types."); + + if constexpr(is_same::value) + { + return data_.d1x64_; + } + else if constexpr(is_same::value) + { + return data_.d2x32_; + } + else if constexpr(is_same::value) + { + return data_.d4x16_; + } + else if constexpr(is_same::value) + { + return data_.d8x8_; + } + else if constexpr(is_same::value) + { + return data_.d16x4_; + } + else if constexpr(is_same::value) + { + return data_.d32x2_; + } + else if constexpr(is_same::value) + { + return data_.d64x1_; + } + else + { + return err; + } + } +}; + +// fp32 +using float2_t = typename vector_type::type; +using float4_t = typename vector_type::type; +using float8_t = typename vector_type::type; +using float16_t = typename vector_type::type; +using float32_t = typename vector_type::type; +using float64_t = typename vector_type::type; + +// fp16 +using half2_t = typename vector_type::type; +using half4_t = typename vector_type::type; +using half8_t = typename vector_type::type; +using half16_t = typename vector_type::type; +using half32_t = typename vector_type::type; + +// bfp16 +using bhalf2_t = typename vector_type::type; +using bhalf4_t = typename vector_type::type; +using bhalf8_t = typename vector_type::type; +using bhalf16_t = typename vector_type::type; +using bhalf32_t = typename vector_type::type; + +// i32 +using int32x2_t = typename vector_type::type; +using int32x4_t = typename vector_type::type; +using int32x6_t = typename vector_type::type; +using int32x8_t = typename vector_type::type; +using int32x16_t = typename vector_type::type; +using int32x32_t = typename vector_type::type; +using int32x64_t = typename vector_type::type; + +// i8 +using int8x2_t = typename vector_type::type; +using int8x4_t = typename vector_type::type; +using int8x8_t = typename vector_type::type; +using int8x16_t = typename vector_type::type; +using int8x32_t = typename vector_type::type; +using int8x64_t = typename vector_type::type; + +// f8 +using f8x2_fnuz_t = typename vector_type::type; +using f8x4_fnuz_t = typename vector_type::type; +using f8x8_fnuz_t = typename vector_type::type; +using f8x16_fnuz_t = typename vector_type::type; +using f8x32_fnuz_t = typename vector_type::type; +using f8x64_fnuz_t = typename vector_type::type; + +// bf8 +using bf8x2_fnuz_t = typename vector_type::type; +using bf8x4_fnuz_t = typename vector_type::type; +using bf8x8_fnuz_t = typename vector_type::type; +using bf8x16_fnuz_t = typename vector_type::type; +using bf8x32_fnuz_t = typename vector_type::type; +using bf8x64_fnuz_t = typename vector_type::type; + +// f8 +using f8x2_ocp_t = typename vector_type::type; +using f8x4_ocp_t = typename vector_type::type; +using f8x8_ocp_t = typename vector_type::type; +using f8x16_ocp_t = typename vector_type::type; +using f8x32_ocp_t = typename vector_type::type; +using f8x64_ocp_t = typename vector_type::type; + +// bf8 +using bf8x2_ocp_t = typename vector_type::type; +using bf8x4_ocp_t = typename vector_type::type; +using bf8x8_ocp_t = typename vector_type::type; +using bf8x16_ocp_t = typename vector_type::type; +using bf8x32_ocp_t = typename vector_type::type; +using bf8x64_ocp_t = typename vector_type::type; + +#if CK_FP8_TYPE_OCP +// f8 +using f8x2_t = f8x2_ocp_t; +using f8x4_t = f8x4_ocp_t; +using f8x8_t = f8x8_ocp_t; +using f8x16_t = f8x16_ocp_t; +using f8x32_t = f8x32_ocp_t; +using f8x64_t = f8x64_ocp_t; + +// bf8 +using bf8x2_t = bf8x2_ocp_t; +using bf8x4_t = bf8x4_ocp_t; +using bf8x8_t = bf8x8_ocp_t; +using bf8x16_t = bf8x16_ocp_t; +using bf8x32_t = bf8x32_ocp_t; +using 
bf8x64_t = bf8x64_ocp_t; +#elif CK_FP8_TYPE_FNUZ +// f8 +using f8x2_t = f8x2_fnuz_t; +using f8x4_t = f8x4_fnuz_t; +using f8x8_t = f8x8_fnuz_t; +using f8x16_t = f8x16_fnuz_t; +using f8x32_t = f8x32_fnuz_t; +using f8x64_t = f8x64_fnuz_t; + +// bf8 +using bf8x2_t = bf8x2_fnuz_t; +using bf8x4_t = bf8x4_fnuz_t; +using bf8x8_t = bf8x8_fnuz_t; +using bf8x16_t = bf8x16_fnuz_t; +using bf8x32_t = bf8x32_fnuz_t; +using bf8x64_t = bf8x64_fnuz_t; +#endif + +// u8 +using uint8x2_t = typename vector_type::type; +using uint8x4_t = typename vector_type::type; +using uint8x8_t = typename vector_type::type; +using uint8x16_t = typename vector_type::type; +using uint8x32_t = typename vector_type::type; +using uint8x64_t = typename vector_type::type; + +// f4 +using f4x2_t = typename vector_type::type; +using f4x4_t = typename vector_type::type; +using f4x8_t = typename vector_type::type; +using f4x16_t = typename vector_type::type; +using f4x32_t = typename vector_type::type; +using f4x64_t = typename vector_type::type; + +// f6 +using f6x16_t = typename vector_type::type; +using f6x16x2_t = typename vector_type::type; +using f6x32_t = typename vector_type::type; + +// bf6 +using bf6x16_t = typename vector_type::type; +using bf6x16x2_t = typename vector_type::type; +using bf6x32_t = typename vector_type::type; + +// e8m0 +using e8m0x4_bexp_t = typename vector_type::type; + +// pack int4 +using pk_i4x2_t = typename vector_type::type; +using pk_i4x4_t = typename vector_type::type; +using pk_i4x8_t = typename vector_type::type; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/dynamic_buffer.hpp b/csrc/rocm/ck_extensions/include/ck/utility/dynamic_buffer.hpp new file mode 100755 index 000000000000..ed42b22daf9a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/dynamic_buffer.hpp @@ -0,0 +1,492 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/data_type.hpp" +#include "enable_if.hpp" +#include "c_style_pointer_cast.hpp" +#if __clang_major__ == 20 +#include "amd_buffer_addressing_builtins.hpp" +#else +#include "amd_buffer_addressing.hpp" +#endif +#include "generic_memory_space_atomic.hpp" + +namespace ck { + +// T may be scalar or vector +// X may be scalar or vector +// T and X have same scalar type +// X contains multiple T +template +struct DynamicBuffer +{ + using type = T; + + T* p_data_; + ElementSpaceSize element_space_size_; + T invalid_element_value_ = T{0}; + + // XXX: PackedSize semantics for pk_i4_t is different from the other packed types. + // Objects of f4x2_pk_t and f6_pk_t are counted as 1 element, while + // objects of pk_i4_t are counted as 2 elements. Therefore, element_space_size_ for pk_i4_t must + // be divided by 2 to correctly represent the number of addressable elements. 
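The generic f8x*_t / bf8x*_t aliases defined just above resolve to either the OCP or the FNUZ vector types depending on which fp8 flavor the build enables, so downstream kernels can stay agnostic. A compile-time sketch of what each build configuration implies:

    #if CK_FP8_TYPE_OCP
    static_assert(ck::is_same<ck::f8x4_t, ck::f8x4_ocp_t>::value, "OCP build: generic aliases are OCP vectors");
    #elif CK_FP8_TYPE_FNUZ
    static_assert(ck::is_same<ck::f8x4_t, ck::f8x4_fnuz_t>::value, "FNUZ build: generic aliases are FNUZ vectors");
    #endif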
+ static constexpr index_t PackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size) + : p_data_{p_data}, element_space_size_{element_space_size} + { + } + + __host__ __device__ constexpr DynamicBuffer(T* p_data, + ElementSpaceSize element_space_size, + T invalid_element_value) + : p_data_{p_data}, + element_space_size_{element_space_size}, + invalid_element_value_{invalid_element_value} + { + } + + __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace() + { + return BufferAddressSpace; + } + + __host__ __device__ constexpr const T& operator[](IndexType i) const { return p_data_[i]; } + + __host__ __device__ constexpr T& operator()(IndexType i) { return p_data_[i]; } + + template >::type, + typename scalar_type>::type>::value || + !is_native_type(), + bool>::type = false> + __host__ __device__ constexpr auto Get(IndexType i, bool is_valid_element) const + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + +#if CK_USE_AMD_BUFFER_LOAD + bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t); +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + if constexpr(InvalidElementUseNumericalZeroValue) + { + return amd_buffer_load_invalid_element_return_zero, + t_per_x, + coherence>( + p_data_, i, is_valid_element, element_space_size_ / PackedSize); + } + else + { + return amd_buffer_load_invalid_element_return_customized_value, + t_per_x, + coherence>( + p_data_, + i, + is_valid_element, + element_space_size_ / PackedSize, + invalid_element_value_); + } + } + else + { + if(is_valid_element) + { +#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS + X tmp; + + __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X)); + + return tmp; +#else + return *c_style_pointer_cast(&p_data_[i]); +#endif + } + else + { + if constexpr(InvalidElementUseNumericalZeroValue) + { + return X{0}; + } + else + { + return X{invalid_element_value_}; + } + } + } + } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ void Update(IndexType i, bool is_valid_element, const X& x) + { + if constexpr(Op == InMemoryDataOperationEnum::Set) + { + this->template Set(i, is_valid_element, x); + } + else if constexpr(Op == InMemoryDataOperationEnum::AtomicAdd) + { + this->template AtomicAdd(i, is_valid_element, x); + } + else if constexpr(Op == InMemoryDataOperationEnum::AtomicMax) + { + this->template AtomicMax(i, is_valid_element, x); + } + else if constexpr(Op == InMemoryDataOperationEnum::Add) + { + auto tmp = this->template Get(i, is_valid_element); + using scalar_t = typename scalar_type>::type; + // handle bfloat addition + if constexpr(is_same_v) + { + if constexpr(is_scalar_type::value) + { + // Scalar type + auto result = + type_convert(type_convert(x) + type_convert(tmp)); + this->template Set(i, is_valid_element, result); + } + else + { + // Vector type + constexpr auto vector_size = scalar_type>::vector_size; + const vector_type a_vector{tmp}; + const vector_type b_vector{x}; + static_for<0, vector_size, 
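PackedSize exists because one pk_i4_t object carries two logical int4 values, so element_space_size_ is halved before it reaches the buffer-addressing intrinsics, while all other packed types count one element per object. A compile-time sketch, assuming the template parameter order AddressSpace, T, ElementSpaceSize, InvalidElementUseNumericalZeroValue with the remaining parameters defaulted:

    static_assert(ck::DynamicBuffer<ck::AddressSpaceEnum::Global, ck::pk_i4_t, ck::index_t, true>::PackedSize == 2,
                  "pk_i4_t packs two elements per stored object");
    static_assert(ck::DynamicBuffer<ck::AddressSpaceEnum::Global, float, ck::index_t, true>::PackedSize == 1,
                  "all other element types count one element per object");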
1>{}([&](auto idx) { + auto result = type_convert( + type_convert(a_vector.template AsType()[idx]) + + type_convert(b_vector.template AsType()[idx])); + this->template Set(i + idx, is_valid_element, result); + }); + } + } + else + { + this->template Set(i, is_valid_element, x + tmp); + } + } + } + + template + __host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf, + IndexType src_offset, + IndexType dst_offset, + bool is_valid_element) const + { + // Copy data from global to LDS memory using direct loads. + static_assert(GetAddressSpace() == AddressSpaceEnum::Global, + "Source data must come from a global memory buffer."); + static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds, + "Destination data must be stored in an LDS memory buffer."); + + amd_direct_load_global_to_lds(p_data_, + src_offset, + dst_buf.p_data_, + dst_offset, + is_valid_element, + element_space_size_ / PackedSize); + } + + template >::type, + typename scalar_type>::type>::value || + !is_native_type(), + bool>::type = false> + __host__ __device__ void Set(IndexType i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + +#if CK_USE_AMD_BUFFER_LOAD + bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t); +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + +#if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE + bool constexpr workaround_int8_ds_write_issue = true; +#else + bool constexpr workaround_int8_ds_write_issue = false; +#endif + + if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_store, t_per_x, coherence>( + x, p_data_, i, is_valid_element, element_space_size_ / PackedSize); + } + else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds && + is_same>::type, int8_t>::value && + workaround_int8_ds_write_issue) + { + if(is_valid_element) + { + // HACK: compiler would lower IR "store address_space(3)" into inefficient + // ISA, so I try to let compiler emit IR "store" which would be lower to + // ds_write_b128 + // TODO: remove this after compiler fix + static_assert((is_same, int8_t>::value && + is_same, int8_t>::value) || + (is_same, int8_t>::value && + is_same, int8x2_t>::value) || + (is_same, int8_t>::value && + is_same, int8x4_t>::value) || + (is_same, int8_t>::value && + is_same, int8x8_t>::value) || + (is_same, int8_t>::value && + is_same, int8x16_t>::value) || + (is_same, int8x4_t>::value && + is_same, int8x4_t>::value) || + (is_same, int8x8_t>::value && + is_same, int8x8_t>::value) || + (is_same, int8x16_t>::value && + is_same, int8x16_t>::value), + "wrong! 
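The InMemoryDataOperationEnum::Add path above emulates bhalf_t accumulation by widening to float, adding, and narrowing back, since bf16 has no native add on the target; the vector case simply applies this per lane. The scalar step reduces to a helper like this (type_convert and bhalf_t come from CK's data_type utilities; the function name is illustrative):

    __host__ __device__ inline ck::bhalf_t add_bf16(ck::bhalf_t a, ck::bhalf_t b)
    {
        // widen both operands to float, add, then convert back to bf16
        return ck::type_convert<ck::bhalf_t>(ck::type_convert<float>(a) + ck::type_convert<float>(b));
    }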
not implemented for this combination, please add " + "implementation"); + + if constexpr(is_same, int8_t>::value && + is_same, int8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x2_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x4_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8_t>::value && + is_same, int8x16_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8x4_t>::value && + is_same, int8x4_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8x8_t>::value && + is_same, int8x8_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + else if constexpr(is_same, int8x16_t>::value && + is_same, int8x16_t>::value) + { + // HACK: cast pointer of x is bad + // TODO: remove this after compiler fix + *c_style_pointer_cast(&p_data_[i]) = + *c_style_pointer_cast(&x); + } + } + } + else + { + if(is_valid_element) + { +#if 0 + X tmp = x; + + __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X)); +#else + // if(i >= 2169041600) + *c_style_pointer_cast(&p_data_[i]) = x; +#endif + } + } + } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ void AtomicAdd(IndexType i, bool is_valid_element, const X& x) + { + using scalar_t = typename scalar_type>::type; + + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! 
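The int8 LDS branch above works around inefficient lowering of narrow address_space(3) stores: every accepted (T, X) combination funnels through the same cast-and-store pattern so the compiler can emit a single wide ds_write. Reduced to its essence (the helper name is illustrative; c_style_pointer_cast is the CK cast used above):

    template <typename X, typename T>
    __device__ void lds_store_wide(T* p, ck::index_t i, const X& x)
    {
        // reinterpret the destination as one X-sized slot and store it in a single write
        *ck::c_style_pointer_cast<X*>(&p[i]) = x;
    }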
X should contain multiple T"); + + static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem"); + +#if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT + bool constexpr use_amd_buffer_addressing = + is_same_v, int32_t> || + is_same_v, float> || + (is_same_v, half_t> && scalar_per_x_vector % 2 == 0) || + (is_same_v, bhalf_t> && scalar_per_x_vector % 2 == 0); +#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT) + bool constexpr use_amd_buffer_addressing = + sizeof(IndexType) <= sizeof(int32_t) && is_same_v, int32_t>; +#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT + bool constexpr use_amd_buffer_addressing = + sizeof(IndexType) <= sizeof(int32_t) && + (is_same_v, float> || + (is_same_v, half_t> && scalar_per_x_vector % 2 == 0) || + (is_same_v, bhalf_t> && scalar_per_x_vector % 2 == 0)); +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_atomic_add, t_per_x>( + x, p_data_, i, is_valid_element, element_space_size_ / PackedSize); + } + else + { + if(is_valid_element) + { + atomic_add(c_style_pointer_cast(&p_data_[i]), x); + } + } + } + + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ void AtomicMax(IndexType i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr IndexType scalar_per_t_vector = scalar_type>::vector_size; + + constexpr IndexType scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X should contain multiple T"); + + static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem"); + +#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 + using scalar_t = typename scalar_type>::type; + bool constexpr use_amd_buffer_addressing = + sizeof(IndexType) <= sizeof(int32_t) && is_same_v, double>; +#else + bool constexpr use_amd_buffer_addressing = false; +#endif + + if constexpr(use_amd_buffer_addressing) + { + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_atomic_max, t_per_x>( + x, p_data_, i, is_valid_element, element_space_size_ / PackedSize); + } + else if(is_valid_element) + { + atomic_max(c_style_pointer_cast(&p_data_[i]), x); + } + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return false; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } +}; + +template +__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size) +{ + return DynamicBuffer{ + p, element_space_size}; +} + +template +__host__ __device__ constexpr auto make_long_dynamic_buffer(T* p, + ElementSpaceSize element_space_size) +{ + return DynamicBuffer{ + p, element_space_size}; +} + +template < + AddressSpaceEnum BufferAddressSpace, + AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence, + typename T, + typename ElementSpaceSize, + typename X, + typename enable_if, remove_cvref_t>::value, bool>::type = false> +__host__ __device__ constexpr auto +make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value) +{ + return DynamicBuffer{ + p, element_space_size, invalid_element_value}; +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/e8m0.hpp 
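A usage sketch for the make_dynamic_buffer factory defined at the end of dynamic_buffer.hpp above: the two-argument construction selects the return-zero-on-invalid behavior, so masked-out reads come back as 0 (the function and variable names are illustrative):

    __device__ float load_or_zero(float* p, ck::index_t n, ck::index_t i, bool valid)
    {
        auto buf = ck::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(p, n);
        return buf.template Get<float>(i, valid);  // returns 0.0f when valid is false
    }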
b/csrc/rocm/ck_extensions/include/ck/utility/e8m0.hpp new file mode 100755 index 000000000000..f7d2a2f5948b --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/e8m0.hpp @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/type.hpp" + +namespace ck { + +/** + * @brief Unsigned representation of a conventional biased Float32 exponent. + * + * bias = 127; + * + * E8M0_1 = 0b01111111; => 2^(127-127) = 1 + * E8M0_2 = 0b10000000; => 2^(128-127) = 2^1 = 2 + * E8M0_3 = 0b10000010; => 2^(130-127) = 2^3 = 8 + * E8M0_135 = 0b10000111; => 2^(135-127) = 2^8 = 256 + * E8M0_142 = 0b10001110; => 2^(142-127) = 2^15 = 32768 + * E8M0_MIN = 0b00000000; => 2^-127 + * E8M0_MAX = 0b11111110; => 2^127 + * E8M0_NAN = 0b11111111; => NaN + */ +struct e8m0_bexp_t +{ + using type = uint8_t; + type data; + + constexpr static type bias = 127; + constexpr static type nan_mask = 0xFF; + + __host__ __device__ constexpr e8m0_bexp_t() : data{type{}} {} + __host__ __device__ constexpr e8m0_bexp_t(type init) : data{init} {} + __host__ __device__ constexpr e8m0_bexp_t(int init) : data{static_cast(init & nan_mask)} + { + } + __host__ __device__ explicit constexpr e8m0_bexp_t(float scale) + : data{static_cast((bit_cast(scale) & (nan_mask << 23)) >> 23)} + { + } + + __host__ __device__ explicit constexpr operator float() const + { + if(data == nan_mask || data == 0) + { + uint32_t bits = data << 1; + bits |= 1; + bits <<= 22; + return bit_cast(bits); + } + else + { + uint32_t bits = data << 23; + return bit_cast(bits); + } + } + + __host__ __device__ constexpr bool operator==(const e8m0_bexp_t& other) const + { + // strict IEEE compliance for NaN + return data == other.data && data != nan_mask; + } + + __host__ __device__ constexpr bool is_nan() const { return data == nan_mask; } +}; + +namespace utils { + +template +__host__ __device__ inline constexpr int32_t get_exponent_value(T x); + +template <> +__host__ __device__ inline constexpr int32_t get_exponent_value(e8m0_bexp_t x) +{ + return x.data; +} + +} // namespace utils + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/enable_if.hpp b/csrc/rocm/ck_extensions/include/ck/utility/enable_if.hpp new file mode 100755 index 000000000000..9d5403ceb209 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/enable_if.hpp @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) +template +struct enable_if +{ +}; + +template +struct enable_if +{ + using type = T; +}; + +template +using enable_if_t = typename enable_if::type; + +#else +template +using enable_if = std::enable_if; + +template +using enable_if_t = typename std::enable_if::type; +#endif +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/env.hpp b/csrc/rocm/ck_extensions/include/ck/utility/env.hpp new file mode 100755 index 000000000000..2f5b804d16bf --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/env.hpp @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
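A worked example of the exponent-only e8m0 format documented above: the biased exponent 130 encodes 2^(130-127), so decoding yields 8.0f (the function name is illustrative):

    __host__ __device__ inline float e8m0_example()
    {
        ck::e8m0_bexp_t scale{static_cast<uint8_t>(0b10000010)};  // biased exponent 130
        return static_cast<float>(scale);                         // 2^(130-127) = 8.0f
    }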
+#pragma once + +#ifndef CK_CODE_GEN_RTC + +#include +#include +#include +#include +#include + +namespace ck { +namespace internal { +template +struct ParseEnvVal +{ +}; + +template <> +struct ParseEnvVal +{ + static bool parse_env_var_value(const char* vp) + { + std::string value_env_str{vp}; + + for(auto& c : value_env_str) + { + if(std::isalpha(c) != 0) + { + c = std::tolower(static_cast(c)); + } + } + + if(value_env_str == "disable" || value_env_str == "disabled" || value_env_str == "0" || + value_env_str == "no" || value_env_str == "off" || value_env_str == "false") + { + return false; + } + else if(value_env_str == "enable" || value_env_str == "enabled" || value_env_str == "1" || + value_env_str == "yes" || value_env_str == "on" || value_env_str == "true") + { + return true; + } + else + { + throw std::runtime_error("Invalid value for env variable"); + } + + return false; // shouldn't reach here + } +}; + +// Supports hexadecimals (with leading "0x"), octals (if prefix is "0") and decimals (default). +// Returns 0 if environment variable is in wrong format (strtoull fails to parse the string). +template <> +struct ParseEnvVal +{ + static uint64_t parse_env_var_value(const char* vp) { return std::strtoull(vp, nullptr, 0); } +}; + +template <> +struct ParseEnvVal +{ + static std::string parse_env_var_value(const char* vp) { return std::string{vp}; } +}; + +template +struct EnvVar +{ + private: + T value{}; + bool is_unset = true; + + public: + const T& GetValue() const { return value; } + + bool IsUnset() const { return is_unset; } + + void Unset() { is_unset = true; } + + void UpdateValue(const T& val) + { + is_unset = false; + value = val; + } + + explicit EnvVar(const char* const name, const T& def_val) + { + // NOLINTNEXTLINE (concurrency-mt-unsafe) + const char* vp = std::getenv(name); + if(vp != nullptr) // a value was provided + { + is_unset = false; + value = ParseEnvVal::parse_env_var_value(vp); + } + else // no value provided, use default value + { + value = def_val; + } + } +}; +} // end namespace internal + +// static inside function hides the variable and provides +// thread-safety/locking +// Used in global namespace +#define CK_DECLARE_ENV_VAR(name, type, default_val) \ + namespace ck::env { \ + struct name \ + { \ + static_assert(std::is_same_v, \ + "CK_DECLARE_ENV* must be used in the global namespace"); \ + using value_type = type; \ + static ck::internal::EnvVar& Ref() \ + { \ + static ck::internal::EnvVar var{#name, default_val}; \ + return var; \ + } \ + }; \ + } + +#define CK_DECLARE_ENV_VAR_BOOL(name) CK_DECLARE_ENV_VAR(name, bool, false) + +#define CK_DECLARE_ENV_VAR_UINT64(name) CK_DECLARE_ENV_VAR(name, uint64_t, 0) + +#define CK_DECLARE_ENV_VAR_STR(name) CK_DECLARE_ENV_VAR(name, std::string, "") + +#define CK_ENV(name) \ + ck::env::name {} + +template +inline const std::string& EnvGetString(EnvVar) +{ + static_assert(std::is_same_v); + return EnvVar::Ref().GetValue(); +} + +template +inline bool EnvIsEnabled(EnvVar) +{ + static_assert(std::is_same_v); + return !EnvVar::Ref().IsUnset() && EnvVar::Ref().GetValue(); +} + +template +inline bool EnvIsDisabled(EnvVar) +{ + static_assert(std::is_same_v); + return !EnvVar::Ref().IsUnset() && !EnvVar::Ref().GetValue(); +} + +template +inline uint64_t EnvValue(EnvVar) +{ + static_assert(std::is_same_v); + return EnvVar::Ref().GetValue(); +} + +template +inline bool EnvIsUnset(EnvVar) +{ + return EnvVar::Ref().IsUnset(); +} + +template +void EnvUnset(EnvVar) +{ + EnvVar::Ref().Unset(); +} + +/// updates the cached value 
of an environment variable +template +void UpdateEnvVar(EnvVar, const ValueType& val) +{ + static_assert(std::is_same_v); + EnvVar::Ref().UpdateValue(val); +} + +template +void UpdateEnvVar(EnvVar, const std::string_view& val) +{ + EnvVar::Ref().UpdateValue( + ck::internal::ParseEnvVal::parse_env_var_value(val.data())); +} + +} // namespace ck + +// environment variable to enable logging: +// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED +CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) + +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/utility/f8_utils.hpp b/csrc/rocm/ck_extensions/include/ck/utility/f8_utils.hpp new file mode 100755 index 000000000000..799683ae6555 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/f8_utils.hpp @@ -0,0 +1,293 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/numeric_utils.hpp" + +namespace ck { + +// fp8 rounding modes +// use standard for rounding to nearest, the faster one +// use stochastic for stochastic rounding, helps to avoid error accumulation +enum class f8_rounding_mode +{ + standard, + stochastic +}; + +__host__ inline int clz(uint32_t x) { return __builtin_clz(x); } +__device__ inline int clz(uint32_t x) { return __clz(x); } + +} // namespace ck + +namespace ck::utils { + +namespace { + +template +__host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) +{ + // fp8/bf8 exponent/mantissa layout + constexpr int out_exp = NumericUtils::exp; + constexpr int out_mant = NumericUtils::mant; + + // original type exponent/mantissa layout + constexpr int in_exp = NumericUtils::exp; + constexpr int in_mant = NumericUtils::mant; + + int exponent, bias; + uint32_t head, mantissa, sign; + // nan code is same for float and half + constexpr Y nan_code = 0x80; + constexpr uint32_t nan_mask = NumericUtils::nan_mask; + + // convert to bitwise + using T_bitwise = typename NumericUtils::bitwise_type; + T_bitwise x_bitwise = bit_cast(x); + + // unpack the input, depends on datatype + head = x_bitwise & NumericUtils::head_mask; + mantissa = x_bitwise & NumericUtils::mant_mask; + exponent = (head >> in_mant) & NumericUtils::exp_mask; + sign = head >> (in_exp + in_mant); + bias = NumericUtils::bias; + + uint32_t signed_inf = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant); + uint32_t drop_mask = (1 << (in_mant - out_mant)) - 1; + constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2); + + if constexpr(negative_zero_nan) + { + if((x_bitwise & nan_mask) == nan_mask) + return nan_code; + } + else + { + if((x_bitwise & nan_mask) == nan_mask) + return signed_inf + (mantissa != 0 ? 1 : 0); + } + + // check if x is 0.0 + if(x_bitwise == 0) + return 0; + + // First need to check if it is normal or denorm as there is a difference of implict 1 + // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift + // The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for + // RNE, no need to add rng. Then probably need to check whether there is carry and adjust + // exponent and mantissa again3 + + // For IEEE bias mode, the bias is 2^(k-1)-1 where k is the width of exponent bits + const int out_bias = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 
1 : 0); + const int out_denormal_act_exponent = 1 - out_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // out_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, out_exponent, exponent_diff; + + if(exponent == 0) + { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16 +here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal has +exponent bias 15 while bf8 with NANOO has exponent bias 16. It means that there are some numbers in +fp16 denormal but they are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers +where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 (NANOO) normal. +In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = out_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } + else + { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if(act_exponent <= out_denormal_act_exponent) + { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal range. + For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16 + actual exponent is -7, it is actually larger due to the implict 1, + Therefore it needs to be adjust to -6 and mantissa shift right by 1. + So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = out_denormal_act_exponent - act_exponent; + } + else + { // both fp32/fp16 and f8 are in normal range + exponent_diff = + 0; // exponent_diff=0 does not mean there is no difference for this case, + // act_exponent could be larger. Just that it does not need shift mantissa + } + mantissa += (1 << in_mant); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) == + (1 << (in_mant - out_mant + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we + shift right as shift right could rip off some residual part and make something not midpoint look + like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than + midpoint, but after shift right by 4 bits, it would look like midpoint. */ + + if(exponent_diff > 0) + mantissa >>= exponent_diff; + else if(exponent_diff == -1) + mantissa <<= -exponent_diff; + bool implicit_one = mantissa & (1 << in_mant); + // if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent + out_exponent = + (act_exponent + exponent_diff) /*actual f8 exponent*/ + out_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + bool odd = + mantissa & + (1 << (in_mant - out_mant)); // if the least significant bit that is not truncated is 1 + mantissa += (stoch ? rng : (midpoint ? (odd ? 
mantissa : mantissa - 1) : mantissa)) & drop_mask; + + // Now we deal with overflow + if(out_exponent == 0) + { + if((1 << in_mant) & mantissa) + { + out_exponent = 1; // denormal overflow to become normal, promote exponent + // No need to make 1 implicit now as it will be addressed later + } + } + else + { + if((1 << (in_mant + 1)) & mantissa) + { + mantissa >>= 1; + out_exponent++; + // No need to make 1 implicit now as it will be addressed later + } + } + + mantissa >>= (in_mant - out_mant); + + if(out_exponent > max_exp) + { + if constexpr(clip) + { + mantissa = (1 << out_mant) - 1; + out_exponent = max_exp; + } + else + { + return signed_inf; + } + } + + // check if x is 0.0 or -0.0 + if(out_exponent == 0 && mantissa == 0) + return negative_zero_nan ? 0 : (sign << (out_exp + out_mant)); + mantissa &= (1 << out_mant) - 1; + return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa; +} + +template +__host__ __device__ Y run_cast_from_f8(X x) +{ + // fp8/bf8 exponent/mantissa layout + constexpr int in_exp = NumericUtils::exp; + constexpr int in_mant = NumericUtils::mant; + + // resulting type exponent/mantissa layout + constexpr int out_exp = NumericUtils::exp; + constexpr int out_mant = NumericUtils::mant; + + // prepare the codes + constexpr X nan_code = 0x80; + using T_bitwise = typename NumericUtils::bitwise_type; + + constexpr T_bitwise Inf_bitwise = NumericUtils::Inf; + constexpr T_bitwise NegInf_bitwise = NumericUtils::NegInf; + constexpr T_bitwise NaN_bitwise = NumericUtils::NaN; + constexpr T_bitwise Neg0_bitwise = NumericUtils::Neg0; + + constexpr Y Inf = bit_cast(Inf_bitwise); + constexpr Y NegInf = bit_cast(NegInf_bitwise); + constexpr Y NaN = bit_cast(NaN_bitwise); + constexpr Y Neg0 = bit_cast(Neg0_bitwise); + + // check if x is 0.0 + if(x == 0) + return static_cast(0); + + // unpack the input + uint32_t sign = x >> (in_exp + in_mant); + uint32_t mantissa = x & ((1 << in_mant) - 1); + int exponent = (x & 0x7F) >> in_mant; + + constexpr int exp_low_cutoff = + (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0); + T_bitwise retval; + + if constexpr(negative_zero_nan) + { + if(x == nan_code) + return NaN; + } + else + { + if(x == nan_code) + return Neg0; + if(exponent == ((1 << in_exp) - 1)) + return (mantissa == 0) ? (sign ? 
NegInf : Inf) : NaN; + } + + if constexpr((NumericUtils::mant == 10) && (NumericUtils::mant == 2) && + !negative_zero_nan) + { + retval = x; + retval <<= 8; + return bit_cast(retval); + } + + // subnormal input + if(exponent == 0) + { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + clz(mantissa) - (32 - in_mant); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << in_mant) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= out_mant - in_mant; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if(exponent <= 0) + { + mantissa |= 1 << out_mant; + mantissa >>= 1 - exponent; + exponent = 0; + } + + retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa; + return bit_cast(retval); +} + +} // namespace + +template +__host__ __device__ Y cast_to_f8(X x, uint32_t rng) +{ + // check datatypes + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "Only half and float can be casted."); + + return run_cast_to_f8(x, rng); +} + +template +__host__ __device__ Y cast_from_f8(X x) +{ + // check datatype + constexpr bool is_half = std::is_same::value; + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "only half and float are supported."); + + return run_cast_from_f8(x); +} + +} // namespace ck::utils diff --git a/csrc/rocm/ck_extensions/include/ck/utility/filter_tuple.hpp b/csrc/rocm/ck_extensions/include/ck/utility/filter_tuple.hpp new file mode 100755 index 000000000000..c2e378b879df --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/filter_tuple.hpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/functional.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck::util { + +template +struct filter_tuple_by_modulo +{ + // Validate Stride and Offset. + static_assert(Stride > 0, "Offset must be positive."); + static_assert(Offset >= 0 && Offset < Stride, + "Offset must be positive and less than the stride."); + + // Generate filtered indices for this stride and offset. + static constexpr int new_size = (std::tuple_size_v + Stride - Offset - 1) / Stride; + + template + static constexpr auto to_index(std::index_sequence) + { + return std::index_sequence<(Offset + Is * Stride)...>{}; + } + + using filtered_indices = decltype(to_index(std::make_index_sequence{})); + + // Helper struct to construct the new tuple type from the filtered indices. + template + struct make_filtered_tuple_type_impl; + + template + struct make_filtered_tuple_type_impl> + { + using type = std::tuple...>; + }; + + using type = typename make_filtered_tuple_type_impl::type; +}; + +// Filter a tuple with a stride and offset. +// +// Tuple is a std::tuple or equivalent +// Stride is a positive integer +// Offset is a positive integer smaller than ofset +// +// Evaluates to a smaller tuple type from elements of T with stride M and offset I. +// +// Can be used to filter a tuple of types for sharded instantiations. 
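// --- Illustrative sketch (editor note, not part of this patch): before moving on
// from the f8 conversion routines above, the "standard" rounding path they describe
// is round-to-nearest-even over the dropped mantissa bits (the midpoint/odd checks).
// The simplified re-derivation below shows only that truncation step; it is not the
// CK implementation and ignores the exponent adjustment and overflow handling the
// real cast performs.
#include <cassert>
#include <cstdint>

static uint32_t rne_truncate(uint32_t mantissa, int drop_bits)
{
    const uint32_t drop_mask = (1u << drop_bits) - 1;
    const uint32_t half      = 1u << (drop_bits - 1);
    const uint32_t residual  = mantissa & drop_mask;
    uint32_t kept            = mantissa >> drop_bits;

    if(residual > half)                       // strictly above the midpoint: round up
        kept += 1;
    else if(residual == half && (kept & 1u))  // exact tie: round to the even value
        kept += 1;
    return kept;
}

int main()
{
    // Dropping 4 bits: 0x12 has residual 0x2 < 0x8, so it rounds down to 0x1.
    assert(rne_truncate(0x12, 4) == 0x1);
    // Tie cases: 0x18 (kept bits odd) rounds up to 2; 0x28 (kept bits even) stays 2.
    assert(rne_truncate(0x18, 4) == 0x2);
    assert(rne_truncate(0x28, 4) == 0x2);
    return 0;
}
// --- end of illustrative sketch ---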
+template +using filter_tuple_by_modulo_t = typename filter_tuple_by_modulo::type; + +// Example compile-time test: +// using OriginalTuple = +// std::tuple; +// using NewTuple_Every3rdFrom2nd = filter_tuple_by_modulo_t; +// static_assert(std::is_same_v>, +// "Test Case 1 Failed: Every 3rd from 2nd"); + +} // namespace ck::util diff --git a/csrc/rocm/ck_extensions/include/ck/utility/flush_icache.hpp b/csrc/rocm/ck_extensions/include/ck/utility/flush_icache.hpp new file mode 100755 index 000000000000..7378ba5c26dc --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/flush_icache.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +namespace ck { +static __global__ void flush_icache() +{ + asm __volatile__("s_icache_inv \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" + "s_nop 0 \n\t" :: + :); +} +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/functional.hpp b/csrc/rocm/ck_extensions/include/ck/utility/functional.hpp new file mode 100755 index 000000000000..cd48ed174744 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/functional.hpp @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/integral_constant.hpp" +#include "ck/utility/type.hpp" + +namespace ck { + +// TODO: right? wrong? +struct forwarder +{ + template + __host__ __device__ constexpr T&& operator()(T&& x) const + { + return static_cast(x); + } +}; + +struct swallow +{ + template + __host__ __device__ constexpr swallow(Ts&&...) + { + } +}; + +template +struct logical_and +{ + constexpr bool operator()(const T& x, const T& y) const { return x && y; } +}; + +template +struct logical_or +{ + constexpr bool operator()(const T& x, const T& y) const { return x || y; } +}; + +template +struct logical_not +{ + constexpr bool operator()(const T& x) const { return !x; } +}; + +// Emulate if constexpr +template +struct static_if; + +template <> +struct static_if +{ + using Type = static_if; + + template + __host__ __device__ constexpr auto operator()(F f) const + { + // This is a trick for compiler: + // Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will + // use it, + // this will make "f" a generic lambda, so that "f" won't be compiled + // until being + // instantiated here + f(forwarder{}); + return Type{}; + } + + template + __host__ __device__ static void Else(F) + { + } +}; + +template <> +struct static_if +{ + using Type = static_if; + + template + __host__ __device__ constexpr auto operator()(F) const + { + return Type{}; + } + + template + __host__ __device__ static void Else(F f) + { + // This is a trick for compiler: + // Pass forwarder to lambda "f" as "auto" argument, and make sure "f" will + // use it, + // this will make "f" a generic lambda, so that "f" won't be compiled + // until being + // instantiated here + f(forwarder{}); + } +}; + +template +struct conditional; + +template +struct conditional +{ + using type = X; +}; + +template +struct conditional +{ + using type = Y; +}; + +template +using conditional_t = typename conditional::type; + +// z = predicate ? 
x : y +template +constexpr auto conditional_expr(X&& x, Y&& y) +{ + if constexpr(predicate) + { + return ck::forward(x); + } + else + { + return ck::forward(y); + } +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/functional2.hpp b/csrc/rocm/ck_extensions/include/ck/utility/functional2.hpp new file mode 100755 index 000000000000..ef8b5a435c99 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/functional2.hpp @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/functional.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/type.hpp" + +namespace ck { + +namespace detail { + +template +struct static_for_impl; + +template +struct static_for_impl> +{ + template + __host__ __device__ constexpr void operator()(F f) const + { + swallow{(f(Number{}), 0)...}; + } +}; + +} // namespace detail + +// F signature: F(Number) +template +struct static_for +{ + __host__ __device__ constexpr static_for() + { + static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0, + "Wrong! should satisfy (NEnd - NBegin) % Increment == 0"); + static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd), + "wrongs! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && " + "NBegin >= NEnd)"); + } + + template + __host__ __device__ constexpr void operator()(F f) const + { + detail::static_for_impl::type>{}( + f); + } +}; + +namespace detail { + +template +struct applier +{ + template + __host__ __device__ constexpr void operator()(F f) const + { + // tweak -fbracket-depth if compilation fails. Clang default limit is 256 + (f(Number{}), ...); + } +}; + +template // == sizeof...(Is) +using make_applier = __make_integer_seq; + +} // namespace detail + +template +struct static_for<0, N, 1> : detail::make_applier +{ + using detail::make_applier::operator(); +}; + +template +struct static_for_range +{ + template + __host__ __device__ constexpr void operator()(F f) const + { + // tweak -fbracket-depth if compilation fails. Clang default limit is 256 + (f(Is{}), ...); + } +}; + +template +struct static_for_product; +template +struct static_for_product> : public static_for_range +{ +}; +template +struct static_for_product, Rest...> +{ + template + __host__ __device__ constexpr void operator()(F f) const + { + static_for_product>{}([&](auto i0) { // + static_for_product{}([&](auto... is) { // + f(i0, is...); + }); + }); + } +}; + +struct identity +{ + template + __host__ __device__ constexpr T&& operator()(T&& arg) const noexcept + { + return ck::forward(arg); + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/functional3.hpp b/csrc/rocm/ck_extensions/include/ck/utility/functional3.hpp new file mode 100755 index 000000000000..97605a7adeb8 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/functional3.hpp @@ -0,0 +1,144 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
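// --- Illustrative sketch (editor note, not part of this patch): static_for<0, N, 1>
// in functional2.hpp above expands a fold expression over integral constants, so the
// loop body receives each index as a compile-time constant. The standalone code below
// mirrors that applier pattern with standard-library types only; it does not use the
// CK types themselves.
#include <cstddef>
#include <iostream>
#include <type_traits>
#include <utility>

template <std::size_t... Is, typename F>
constexpr void static_for_impl(std::index_sequence<Is...>, F&& f)
{
    (f(std::integral_constant<std::size_t, Is>{}), ...); // left-to-right fold over the indices
}

template <std::size_t N, typename F>
constexpr void static_for(F&& f)
{
    static_for_impl(std::make_index_sequence<N>{}, std::forward<F>(f));
}

int main()
{
    int acc[4] = {};
    static_for<4>([&](auto i) {
        // i.value is a constant expression, usable as an array index or template argument.
        acc[i.value] = static_cast<int>(i.value) * 10;
    });
    for(int v : acc)
        std::cout << v << ' '; // prints: 0 10 20 30
    std::cout << '\n';
    return 0;
}
// --- end of illustrative sketch ---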
+ +#pragma once + +#include "ck/ck.hpp" +#include "ck/utility/functional.hpp" +#include "ck/utility/functional2.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/multi_index.hpp" + +namespace ck { + +namespace detail { + +// RemainLengths: Sequence<...> +// Orders: Sequence<...> +template +struct static_ford_impl +{ + __host__ __device__ constexpr static_ford_impl() + { + static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here"); + } + + // F signature: F(Sequence<...>) + // CurrentOrderedId: Sequence<...> + template + __host__ __device__ constexpr void operator()(F f, CurrentOrderedId) const + { + static_for<0, RemainLengths::Front(), 1>{}([=](auto I) { + static_ford_impl{}( + f, CurrentOrderedId::PushBack(I)); + }); + } +}; + +template +struct static_ford_impl, Orders> +{ + // F signature: F(Sequence<...>) + // OrderedId: Sequence<...> + template + __host__ __device__ constexpr void operator()(F f, OrderedId) const + { + // retrive unordered Id + f(OrderedId::ReorderGivenOld2New(Orders{})); + } +}; + +// RemainLengths: Sequence<...> +// Orders: Sequence<...> +template +struct ford_impl +{ + __host__ __device__ constexpr ford_impl() + { + static_assert(RemainLengths::GetSize() > 0, "wrong! should not get here"); + } + + // F signature: F(Array<...> multi_id) + // CurrentOrderdId: Array<...> + template + __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const + { + for(index_t i = 0; i < RemainLengths::Front(); ++i) + { + ford_impl{}( + f, container_push_back(current_ordered_id, i)); + } + } +}; + +template +struct ford_impl, Orders> +{ + // F signature: F(Array<...> multi_id) + // CurrentOrderdId: Array<...> + template + __host__ __device__ constexpr void operator()(F f, CurrentOrderedId current_ordered_id) const + { + // retrive unordered Id + f(container_reorder_given_old2new(current_ordered_id, Orders{})); + } +}; + +} // namespace detail + +// Lengths is Sequence<...>, it is the length of each dimension for +// N-dimensional loop +// Orders is Sequence<...>, it is the order of dimension in which static_ford +// will loop over each +// dimension +template ::type> +struct static_ford +{ + __host__ __device__ constexpr static_ford() + { + static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty"); + static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! inconsistent size"); + } + + // F signature: F(Sequence<...> multi_id) + // multi_id is the unordered multi-index + template + __host__ __device__ constexpr void operator()(F f) const + { + constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{}); + detail::static_ford_impl{}(f, Sequence<>{}); + } +}; + +// Lengths is Sequence<...>, it is the length of each dimension for +// N-dimensional loop +// Orders is Sequence<...>, it is the order of dimension in which ford will loop +// over each +// dimension +template ::type> +struct ford +{ + __host__ __device__ constexpr ford() + { + static_assert(Lengths::GetSize() > 0, "wrong! Lengths is empty"); + static_assert(Lengths::GetSize() == Orders::GetSize(), "wrong! 
inconsistent size"); + } + + // F signature: F(Array<...> multi_id) + // multi_id is the unordered multi-index + template + __host__ __device__ constexpr void operator()(F f) const + { + constexpr auto ordered_lengths = Lengths::ReorderGivenNew2Old(Orders{}); + + for(index_t i = 0; i < ordered_lengths.Front(); ++i) + { + detail::ford_impl{}(f, + make_multi_index(i)); + } + } +}; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/functional4.hpp b/csrc/rocm/ck_extensions/include/ck/utility/functional4.hpp new file mode 100755 index 000000000000..8e86a296dc2e --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/functional4.hpp @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#ifndef CK_FUNCTIONAL4_HPP +#define CK_FUNCTIONAL4_HPP + +#include "sequence.hpp" +#include "tuple.hpp" +#include "array.hpp" + +namespace ck { + +namespace detail { + +template +struct unpack_impl; + +template +struct unpack_impl> +{ + template + __host__ __device__ constexpr auto operator()(F&& f, X&& x) const + { + return ck::forward(f)(ck::forward(x).At(Number{})...); + } +}; + +template +struct unpack2_impl; + +// TODO: remove this, after properly implementing unpack that takes any number of containers +template +struct unpack2_impl, Sequence> +{ + template + __host__ __device__ constexpr auto operator()(F&& f, X&& x, Y&& y) const + { + return ck::forward(f)(ck::forward(x).At(Number{})..., + ck::forward(y).At(Number{})...); + } +}; + +} // namespace detail + +template +__host__ __device__ constexpr auto unpack(F&& f, X&& x) +{ + using X_ = remove_reference_t; + return detail::unpack_impl::type>{}( + ck::forward(f), ck::forward(x)); +} + +// TODO: properly implement unpack that takes any number of containers +template +__host__ __device__ constexpr auto unpack2(F&& f, X&& x, Y&& y) +{ + using X_ = remove_reference_t; + using Y_ = remove_reference_t; + return detail::unpack2_impl::type, + typename arithmetic_sequence_gen<0, Y_::Size(), 1>::type>{}( + ck::forward(f), ck::forward(x), ck::forward(y)); +} + +} // namespace ck +#endif diff --git a/csrc/rocm/ck_extensions/include/ck/utility/generic_memory_space_atomic.hpp b/csrc/rocm/ck_extensions/include/ck/utility/generic_memory_space_atomic.hpp new file mode 100755 index 000000000000..011491ffc63e --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/generic_memory_space_atomic.hpp @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include "data_type.hpp" +#include "dtype_fp64.hpp" + +namespace ck { + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to make the implementation of atomic_add explicit for +// each datatype. 
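// --- Illustrative sketch (editor note, not part of this patch): the comment above
// relies on a declaration-only primary template, so any type without an explicit
// specialization fails to link instead of silently picking up a generic definition.
// The standalone example below shows that pattern with made-up names; it is not the
// CK atomic_add itself.
#include <cstdint>
#include <iostream>

template <typename T>
T saturating_add(T a, T b); // declared on purpose, never defined

template <>
int32_t saturating_add<int32_t>(int32_t a, int32_t b)
{
    const int64_t s = static_cast<int64_t>(a) + b;
    if(s > INT32_MAX) return INT32_MAX;
    if(s < INT32_MIN) return INT32_MIN;
    return static_cast<int32_t>(s);
}

int main()
{
    std::cout << saturating_add<int32_t>(INT32_MAX, 5) << '\n'; // 2147483647
    // saturating_add<float>(1.f, 2.f); // compiles here but fails to link:
    //                                  // no definition exists for float.
    return 0;
}
// --- end of illustrative sketch ---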
+template +__device__ X atomic_add(X* p_dst, const X& x); + +template <> +__device__ int32_t atomic_add(int32_t* p_dst, const int32_t& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ uint32_t atomic_add(uint32_t* p_dst, const uint32_t& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ float atomic_add(float* p_dst, const float& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ unsigned short atomic_add(unsigned short* p_dst, const unsigned short& x) +{ + // Use atomicAdd with unsigned int + return static_cast( + atomicAdd(reinterpret_cast(p_dst), static_cast(x))); +} + +template <> +__device__ _Float16 atomic_add<_Float16>(_Float16* p_dst, const _Float16& x) +{ + // Use atomicAdd with unsigned int + return static_cast<_Float16>( + atomicAdd(reinterpret_cast(p_dst), static_cast(x))); +} + +template <> +__device__ double atomic_add(double* p_dst, const double& x) +{ + return atomicAdd(p_dst, x); +} + +template <> +__device__ float2_t atomic_add(float2_t* p_dst, const float2_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicAdd(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicAdd(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + + return vy.template AsType()[I0]; +} + +template <> +__device__ double2_t atomic_add(double2_t* p_dst, const double2_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicAdd(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicAdd(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + + return vy.template AsType()[I0]; +} + +// Caution: DO NOT REMOVE +// intentionally have only declaration but no definition to cause compilation failure when trying to +// instantiate this template. The purpose is to make the implementation of atomic_max explicit for +// each datatype. + +template +__device__ X atomic_max(X* p_dst, const X& x); + +template <> +__device__ int32_t atomic_max(int32_t* p_dst, const int32_t& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ uint32_t atomic_max(uint32_t* p_dst, const uint32_t& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ float atomic_max(float* p_dst, const float& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ double atomic_max(double* p_dst, const double& x) +{ + return atomicMax(p_dst, x); +} + +template <> +__device__ float2_t atomic_max(float2_t* p_dst, const float2_t& x) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + const vector_type vx{x}; + vector_type vy{0}; + + vy.template AsType()(I0) = + atomicMax(c_style_pointer_cast(p_dst), vx.template AsType()[I0]); + vy.template AsType()(I1) = + atomicMax(c_style_pointer_cast(p_dst) + 1, vx.template AsType()[I1]); + + return vy.template AsType()[I0]; +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/get_id.hpp b/csrc/rocm/ck_extensions/include/ck/utility/get_id.hpp new file mode 100755 index 000000000000..fd0d1024b25a --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/get_id.hpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
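// --- Illustrative sketch (editor note, not part of this patch): the float2_t and
// double2_t specializations above split a vector accumulation into per-component
// scalar atomics, since there is no vector-wide atomicAdd. The minimal HIP kernel
// below shows the same idea with the built-in float2 type; it is a usage sketch,
// not the CK buffer abstraction.
#include <hip/hip_runtime.h>
#include <cstdio>

__global__ void accumulate_float2(float2* dst)
{
    const float2 contribution = make_float2(1.0f, 2.0f);
    atomicAdd(&dst->x, contribution.x); // one scalar atomic per component
    atomicAdd(&dst->y, contribution.y);
}

int main()
{
    float2* d_sum = nullptr;
    hipMalloc(&d_sum, sizeof(float2));
    hipMemset(d_sum, 0, sizeof(float2));

    accumulate_float2<<<4, 64>>>(d_sum); // 256 threads in total

    float2 h_sum{};
    hipMemcpy(&h_sum, d_sum, sizeof(float2), hipMemcpyDeviceToHost);
    hipFree(d_sum);

    printf("sum = (%f, %f)\n", h_sum.x, h_sum.y); // expected (256, 512)
    return 0;
}
// --- end of illustrative sketch ---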
+ +#pragma once + +#include "ck/ck.hpp" + +namespace ck { + +__host__ __device__ constexpr index_t get_warp_size() +{ +#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) + return 64; +#else + return 32; +#endif +} + +__device__ index_t get_thread_local_1d_id() { return threadIdx.x; } + +__device__ index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; } + +__device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size(); } + +__device__ index_t get_block_1d_id() { return blockIdx.x; } + +__device__ index_t get_grid_size() { return gridDim.x; } + +__device__ index_t get_block_size() { return blockDim.x; } + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/get_shift.hpp b/csrc/rocm/ck_extensions/include/ck/utility/get_shift.hpp new file mode 100755 index 000000000000..0a93081cfd01 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/get_shift.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +namespace ck { + +template +static constexpr __device__ index_t get_shift() +{ + return (get_shift() + 1); +}; + +template <> +constexpr __device__ index_t get_shift<1>() +{ + return (0); +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/ignore.hpp b/csrc/rocm/ck_extensions/include/ck/utility/ignore.hpp new file mode 100755 index 000000000000..f70a182fd4e5 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/ignore.hpp @@ -0,0 +1,22 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +// https://en.cppreference.com/w/cpp/utility/tuple/ignore + +namespace ck { + +namespace detail { +struct ignore_t +{ + template + constexpr void operator=(T&&) const noexcept + { + } +}; +} // namespace detail + +inline constexpr detail::ignore_t ignore; + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/inner_product.hpp b/csrc/rocm/ck_extensions/include/ck/utility/inner_product.hpp new file mode 100755 index 000000000000..65efaf388ae5 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/inner_product.hpp @@ -0,0 +1,250 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
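// --- Illustrative sketch (editor note, not part of this patch): ignore.hpp above
// provides a std::ignore-style sink; assigning to it discards a value, which is a
// common way to silence unused-parameter warnings in template code. The standalone
// example below reimplements the idea under a sketch namespace rather than using the
// CK symbol.
namespace sketch {
struct ignore_t
{
    template <typename T>
    constexpr void operator=(T&&) const noexcept
    {
    }
};
inline constexpr ignore_t ignore;
} // namespace sketch

template <typename Unused>
int always_42(Unused&& u)
{
    sketch::ignore = u; // discard: keeps -Wunused-parameter quiet without a cast
    return 42;
}

int main() { return always_42(3.14) == 42 ? 0 : 1; }
// --- end of illustrative sketch ---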
+ +#pragma once +#include "data_type.hpp" +#include "type_convert.hpp" + +namespace ck { + +template +__device__ void inner_product(const TA& a, const TB& b, TC& c); + +template <> +__device__ void inner_product(const float& a, const float& b, float& c) +{ +#if CK_USE_AMD_V_MAC_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32) + asm volatile("\n \ + v_mac_f32 %0, %1, %2 \n \ + " + : "=v"(c) + : "v"(a), "v"(b), "0"(c)); +#elif CK_USE_AMD_V_MAC_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32) + asm volatile("\n \ + v_fmac_f32 %0, %1, %2 \n \ + " + : "=v"(c) + : "v"(a), "v"(b), "0"(c)); +#else + c += a * b; +#endif +} + +template <> +__device__ void +inner_product(const float2_t& a, const float2_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void +inner_product(const float4_t& a, const float4_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + inner_product(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + inner_product(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + +template <> +__device__ void inner_product(const bhalf_t& a, const bhalf_t& b, float& c) +{ + inner_product(type_convert(a), type_convert(b), c); +} + +template <> +__device__ void inner_product(const half_t& a, const half_t& b, float& c) +{ + inner_product(type_convert(a), type_convert(b), c); +} + +template <> +__device__ void inner_product(const half2_t& a, const half2_t& b, float& c) +{ +#if defined(CK_USE_AMD_V_DOT2_F32_F16) +#if CK_USE_AMD_V_DOT_INLINE_ASM + // Use 3 x s_nop to avoid hazard (mi200 cdna2 isa page 47 + // https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf + // ) s_nop with parameter 2 is equal to 3 x s_nop + asm volatile("\n \ + v_dot2_f32_f16 %0, %1, %2, %0\n \ + s_nop 2 \n \ + " + : "=v"(c) + : "v"(a), "v"(b), "0"(c)); +#else + c = __builtin_amdgcn_fdot2(a, b, c, false); +#endif +#else + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + static_for<0, 2, 1>{}([&](auto i) { + c += type_convert(a_vector.AsType()[i]) * + type_convert(b_vector.AsType()[i]); + }); +#endif +} + +template <> +__device__ void inner_product(const half4_t& a, const half4_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void inner_product(const half8_t& a, const half8_t& b, float& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + inner_product(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + inner_product(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + +template <> +__device__ void 
inner_product(const int8_t& a, const int8_t& b, int32_t& c) +{ + c += type_convert(a) * type_convert(b); +} + +template <> +__device__ void +inner_product(const int8x2_t& a, const int8x2_t& b, int32_t& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void +inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c) +{ +#if defined(CK_USE_AMD_V_DOT4_I32_I8) +#if CK_USE_AMD_V_DOT_INLINE_ASM + // Use 3 x s_nop to avoid hazard (mi200 cdna2 isa page 47 + // https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf + // ) s_nop with parameter 2 is equal to 3 x s_nop + asm volatile("\n \ + v_dot4_i32_i8 %0, %1, %2, %0\n \ + s_nop 2 \n \ + " + : "=v"(c) + : "v"(bit_cast(a)), "v"(bit_cast(b)), "0"(c)); +#else + c = __builtin_amdgcn_sdot4(bit_cast(a), bit_cast(b), c, false); +#endif +#elif defined(CK_USE_AMD_V_DOT4_I32_I8_GFX11) + c = __builtin_amdgcn_sudot4(true, bit_cast(a), true, bit_cast(b), c, false); +#else + const vector_type a_vector{a}; + const vector_type b_vector{b}; + + static_for<0, 4, 1>{}([&](auto i) { + c += type_convert(a_vector.AsType()[i]) * + type_convert(b_vector.AsType()[i]); + }); +#endif +} + +template <> +__device__ void +inner_product(const int8x8_t& a, const int8x8_t& b, int32_t& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); +} + +template <> +__device__ void +inner_product(const int8x16_t& a, const int8x16_t& b, int32_t& c) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + inner_product(vector_type{a}.AsType()[I0], + vector_type{b}.AsType()[I0], + c); + + inner_product(vector_type{a}.AsType()[I1], + vector_type{b}.AsType()[I1], + c); + + inner_product(vector_type{a}.AsType()[I2], + vector_type{b}.AsType()[I2], + c); + + inner_product(vector_type{a}.AsType()[I3], + vector_type{b}.AsType()[I3], + c); +} + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/inner_product_dpp8.hpp b/csrc/rocm/ck_extensions/include/ck/utility/inner_product_dpp8.hpp new file mode 100755 index 000000000000..f079e2ca6490 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/inner_product_dpp8.hpp @@ -0,0 +1,146 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "amd_gemm_dpp.hpp" +#include "data_type.hpp" +#include "type_convert.hpp" + +namespace ck { + +namespace dpp8 { + +/// Number of lanes that can share data using DPP8 modifiers. 
+constexpr index_t lane_group_size = 8; + +template +__device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c); + +// clang-format off +template <> +__device__ void inline_v_dot2c_dpp8_instr<0>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[0, 0, 0, 0, 0, 0, 0, 0]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<1>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[1, 1, 1, 1, 1, 1, 1, 1]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<2>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[2, 2, 2, 2, 2, 2, 2, 2]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<3>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[3, 3, 3, 3, 3, 3, 3, 3]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<4>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[4, 4, 4, 4, 4, 4, 4, 4]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<5>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[5, 5, 5, 5, 5, 5, 5, 5]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<6>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[6, 6, 6, 6, 6, 6, 6, 6]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +template <> +__device__ void inline_v_dot2c_dpp8_instr<7>(const half2_t& a, const half2_t& b, float& c){ + asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[7, 7, 7, 7, 7, 7, 7, 7]" : "=v"(c) : "v"(a), "v"(b), "0"(c)); +} +// clang-format on + +/** + * Dot product of two vectors using `v_dot` instruction with DPP8 submitted as inline assembly. + */ +template +__device__ void inline_v_dot2c_dpp8(const half2_t& a, const half2_t& b, float& c) +{ + static_assert(SrcLaneIdx >= 0 && SrcLaneIdx < dpp8::lane_group_size, + "DPP8 src broadcast lane out of range <0, 7>."); + if constexpr(ShareA) + { + inline_v_dot2c_dpp8_instr(a, b, c); + } + else + { + inline_v_dot2c_dpp8_instr(b, a, c); + } +} + +/** + * DPP8 instrinsics expects to get an integer mask, hardcoding integers for specific broadcast + * patters. + */ +constexpr std::array IntrinsicMaskDpp8 = { + 0, // 0, 0, 0, 0, 0, 0, 0, 0 + 2396745, // 1, 1, 1, 1, 1, 1, 1, 1 + 4793490, // 2, 2, 2, 2, 2, 2, 2, 2 + 7190235, // 3, 3, 3, 3, 3, 3, 3, 3 + 9586980, // 4, 4, 4, 4, 4, 4, 4, 4 + 11983725, // 5, 5, 5, 5, 5, 5, 5, 5 + 14380470, // 6, 6, 6, 6, 6, 6, 6, 6 + 16777215, // 7, 7, 7, 7, 7, 7, 7, 7 +}; + +/** + * Returns DPP8 sel modifier as an integer required for the intrinsic instruction. 
+ */ +template +constexpr int get_dpp_sel_mask_broadcast() +{ + static_assert(SrcLaneIdx >= 0 && SrcLaneIdx < dpp8::lane_group_size, + "DPP8 src broadcast lane out of range <0, 7>."); + return IntrinsicMaskDpp8[SrcLaneIdx]; +} + +template +__device__ void intrinsic_fdot2_impl(const half2_t& a, const half2_t& b, float& c) +{ + constexpr int sel_mask = get_dpp_sel_mask_broadcast(); + const half2_t val_from_other_lane = + bit_cast(__builtin_amdgcn_mov_dpp8(bit_cast(a), sel_mask)); + c = __builtin_amdgcn_fdot2(val_from_other_lane, b, c, false); +} + +/** + * Dot product of two vectors using `v_dot` instruction with DPP8 submitted using intrinsics. + */ +template +__device__ void intrinsic_fdot2(const half2_t& a, const half2_t& b, float& c) +{ + if constexpr(ShareA) + { + intrinsic_fdot2_impl(a, b, c); + } + else + { + intrinsic_fdot2_impl(b, a, c); + } +} + +/** + * Dot product of two input vectors `a`, `b` using `v_dot` instructions with DPP modifier. + * + * DPP modifier allows us to share one of the vectors between lanes in a lane group. + * When `ShareA` is set, instruction uses vector `a` from lane `SrcLaneIdx` from the same + * lane group (8 lanes per lane group in DPP8). When `ShareA` is not set, vector `b` is shared. + * Note that all the threads in a lane group uses the same vector - broadcast pattern. + * + * `SrcLaneIdx` must be in range from 0 to 7. + */ +template +__device__ void inner_product_dpp(const TA& a, const TB& b, TC& c) +{ +#if CK_USE_AMD_V_DOT_DPP8_INLINE_ASM + inline_v_dot2c_dpp8(a, b, c); +#else + intrinsic_fdot2(a, b, c); +#endif +} + +} // namespace dpp8 + +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/integral_constant.hpp b/csrc/rocm/ck_extensions/include/ck/utility/integral_constant.hpp new file mode 100755 index 000000000000..a7fa64d71029 --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/integral_constant.hpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
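// --- Illustrative sketch (editor note, not part of this patch): the hard-coded
// IntrinsicMaskDpp8 table above is consistent with the dpp8 selector being a 24-bit
// immediate of eight 3-bit source-lane indices packed little-endian, so a broadcast
// from lane L is the sum over i of (L << 3*i). The constexpr code below derives the
// same constants under that assumption; it is not part of CK.
#include <array>
#include <cstdint>

constexpr uint32_t dpp8_selector(const std::array<uint32_t, 8>& lanes)
{
    uint32_t mask = 0;
    for(int i = 0; i < 8; ++i)
        mask |= (lanes[i] & 0x7u) << (3 * i); // 3 bits per destination slot
    return mask;
}

constexpr uint32_t broadcast_selector(uint32_t lane)
{
    return dpp8_selector({lane, lane, lane, lane, lane, lane, lane, lane});
}

static_assert(broadcast_selector(0) == 0u, "all lanes read lane 0");
static_assert(broadcast_selector(1) == 2396745u, "0x249249, matches the table above");
static_assert(broadcast_selector(4) == 9586980u, "0x924924, matches the table above");
static_assert(broadcast_selector(7) == 16777215u, "0xFFFFFF, matches the table above");

int main() { return 0; }
// --- end of illustrative sketch ---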
+ +#pragma once + +namespace ck { + +template +struct constant +{ + using value_type = decltype(v); + using type = constant; // using injected-class-name + static constexpr value_type value = v; + __host__ __device__ constexpr operator value_type() const noexcept { return value; } + __host__ __device__ constexpr value_type operator()() const noexcept { return value; } +}; + +template +struct integral_constant : constant +{ + static constexpr T value = v; + typedef T value_type; + typedef integral_constant type; +}; + +template +__host__ __device__ constexpr auto operator+(integral_constant, integral_constant) +{ + return integral_constant{}; +} + +template +__host__ __device__ constexpr auto operator-(integral_constant, integral_constant) +{ + static_assert(Y <= X, "wrong!"); + return integral_constant{}; +} + +template +__host__ __device__ constexpr auto operator*(integral_constant, integral_constant) +{ + return integral_constant{}; +} + +template +__host__ __device__ constexpr auto operator/(integral_constant, integral_constant) +{ + static_assert(Y > 0, "wrong!"); + return integral_constant{}; +} + +template +__host__ __device__ constexpr auto operator%(integral_constant, integral_constant) +{ + static_assert(Y > 0, "wrong!"); + return integral_constant{}; +} + +template +using bool_constant = integral_constant; + +using true_type = bool_constant; +using false_type = bool_constant; +} // namespace ck diff --git a/csrc/rocm/ck_extensions/include/ck/utility/is_detected.hpp b/csrc/rocm/ck_extensions/include/ck/utility/is_detected.hpp new file mode 100755 index 000000000000..a700fcfff1dd --- /dev/null +++ b/csrc/rocm/ck_extensions/include/ck/utility/is_detected.hpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/integral_constant.hpp" + +namespace ck { + +namespace detail { +template class Op, class... Args> +struct detector +{ + using value_t = integral_constant; + using type = Default; +}; + +template class Op, class... Args> +struct detector>, Op, Args...> +{ + using value_t = integral_constant; + using type = Op; +}; +} // namespace detail + +struct nonesuch +{ + ~nonesuch() = delete; + nonesuch(nonesuch const&) = delete; + void operator=(nonesuch const&) = delete; +}; + +template